Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplifying betweenness_centrality (for vertices) #815

Merged
merged 7 commits into from
Mar 10, 2023
Merged
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
147 changes: 38 additions & 109 deletions rustworkx-core/src/centrality.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ use petgraph::visit::{
NodeCount,
NodeIndexable,
};
use rayon::prelude::*;
use rayon_cond::CondIterator;

/// Compute the betweenness centrality of all nodes in a graph.
Expand All @@ -46,7 +45,7 @@ use rayon_cond::CondIterator;
/// Arguments:
///
/// * `graph` - The graph object to run the algorithm on
/// * `endpoints` - Whether to include the endpoints of paths in the path
/// * `include_endpoints` - Whether to include the endpoints of paths in the path
/// lengths used to compute the betweenness
/// * `normalized` - Whether to normalize the betweenness scores by the number
/// of distinct paths between all pairs of nodes
Expand Down Expand Up @@ -75,7 +74,7 @@ use rayon_cond::CondIterator;
/// [`edge_betweenness_centrality`]
pub fn betweenness_centrality<G>(
graph: G,
endpoints: bool,
include_endpoints: bool,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this ok from a backwards compatibility perspective since in rust the arguments are all passed positionally (this would be breaking in Python so I wanted to just note this and check myself).

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh yeah, that's a good point.

normalized: bool,
parallel_threshold: usize,
) -> Vec<Option<f64>>
Expand Down Expand Up @@ -111,73 +110,27 @@ where
betweenness[is] = Some(0.0);
}
let locked_betweenness = RwLock::new(&mut betweenness);
let node_indices: Vec<usize> = graph
.node_identifiers()
.map(|i| graph.to_index(i))
.collect();
if graph.node_count() < parallel_threshold {
node_indices
.iter()
.map(|node_s| {
(
shortest_path_for_centrality(&graph, &graph.from_index(*node_s)),
*node_s,
)
})
.for_each(|(mut shortest_path_calc, is)| {
if endpoints {
_accumulate_endpoints(
&locked_betweenness,
max_index,
&mut shortest_path_calc,
is,
&graph,
);
} else {
_accumulate_basic(
&locked_betweenness,
max_index,
&mut shortest_path_calc,
is,
&graph,
);
}
});
} else {
node_indices
.par_iter()
.map(|node_s| {
(
shortest_path_for_centrality(&graph, &graph.from_index(*node_s)),
node_s,
)
})
.for_each(|(mut shortest_path_calc, is)| {
if endpoints {
_accumulate_endpoints(
&locked_betweenness,
max_index,
&mut shortest_path_calc,
*is,
&graph,
);
} else {
_accumulate_basic(
&locked_betweenness,
max_index,
&mut shortest_path_calc,
*is,
&graph,
);
}
});
}
let node_indices: Vec<G::NodeId> = graph.node_identifiers().collect();

CondIterator::new(node_indices, graph.node_count() >= parallel_threshold)
.map(|node_s| (shortest_path_for_centrality(&graph, &node_s), node_s))
.for_each(|(mut shortest_path_calc, node_s)| {
_accumulate_vertices(
&locked_betweenness,
max_index,
&mut shortest_path_calc,
node_s,
&graph,
include_endpoints,
);
});

_rescale(
&mut betweenness,
graph.node_count(),
normalized,
graph.is_directed(),
endpoints,
include_endpoints,
);

betweenness
Expand Down Expand Up @@ -275,12 +228,12 @@ fn _rescale(
node_count: usize,
normalized: bool,
directed: bool,
endpoints: bool,
include_endpoints: bool,
) {
let mut do_scale = true;
let mut scale = 1.0;
if normalized {
if endpoints {
if include_endpoints {
if node_count < 2 {
do_scale = false;
} else {
Expand All @@ -303,12 +256,13 @@ fn _rescale(
}
}

fn _accumulate_basic<G>(
fn _accumulate_vertices<G>(
locked_betweenness: &RwLock<&mut Vec<Option<f64>>>,
max_index: usize,
path_calc: &mut ShortestPathData<G>,
is: usize,
node_s: <G as GraphBase>::NodeId,
graph: G,
include_endpoints: bool,
) where
G: NodeIndexable
+ IntoNodeIdentifiers
Expand All @@ -330,47 +284,22 @@ fn _accumulate_basic<G>(
}
}
let mut betweenness = locked_betweenness.write().unwrap();
for w in &path_calc.verts_sorted_by_distance {
let iw = graph.to_index(*w);
if iw != is {
betweenness[iw] = betweenness[iw].map(|x| x + delta[iw]);
}
}
}

fn _accumulate_endpoints<G>(
locked_betweenness: &RwLock<&mut Vec<Option<f64>>>,
max_index: usize,
path_calc: &mut ShortestPathData<G>,
is: usize,
graph: G,
) where
G: NodeIndexable
+ IntoNodeIdentifiers
+ IntoNeighborsDirected
+ NodeCount
+ GraphProp
+ GraphBase
+ std::marker::Sync,
<G as GraphBase>::NodeId: std::cmp::Eq + Hash,
{
let mut delta = vec![0.0; max_index];
for w in &path_calc.verts_sorted_by_distance {
let iw = graph.to_index(*w);
let coeff = (1.0 + delta[iw]) / path_calc.sigma[w];
let p_w = path_calc.predecessors.get(w).unwrap();
for v in p_w {
let iv = graph.to_index(*v);
delta[iv] += path_calc.sigma[v] * coeff;
if include_endpoints {
let i_node_s = graph.to_index(node_s);
betweenness[i_node_s] = betweenness[i_node_s]
.map(|x| x + ((path_calc.verts_sorted_by_distance.len() - 1) as f64));
for w in &path_calc.verts_sorted_by_distance {
if *w != node_s {
let iw = graph.to_index(*w);
betweenness[iw] = betweenness[iw].map(|x| x + delta[iw] + 1.0);
}
}
}
let mut betweenness = locked_betweenness.write().unwrap();
betweenness[is] =
betweenness[is].map(|x| x + ((path_calc.verts_sorted_by_distance.len() - 1) as f64));
for w in &path_calc.verts_sorted_by_distance {
let iw = graph.to_index(*w);
if iw != is {
betweenness[iw] = betweenness[iw].map(|x| x + delta[iw] + 1.0);
} else {
for w in &path_calc.verts_sorted_by_distance {
if *w != node_s {
let iw = graph.to_index(*w);
betweenness[iw] = betweenness[iw].map(|x| x + delta[iw]);
}
}
}
}
Expand Down