diff --git a/rustworkx-core/src/generators/mod.rs b/rustworkx-core/src/generators/mod.rs index 3360ca6a9..b4618ad91 100644 --- a/rustworkx-core/src/generators/mod.rs +++ b/rustworkx-core/src/generators/mod.rs @@ -14,6 +14,7 @@ mod cycle_graph; mod grid_graph; +mod path_graph; mod star_graph; mod utils; @@ -35,4 +36,5 @@ impl fmt::Display for InvalidInputError { pub use cycle_graph::cycle_graph; pub use grid_graph::grid_graph; +pub use path_graph::path_graph; pub use star_graph::star_graph; diff --git a/rustworkx-core/src/generators/path_graph.rs b/rustworkx-core/src/generators/path_graph.rs new file mode 100644 index 000000000..2317e47e8 --- /dev/null +++ b/rustworkx-core/src/generators/path_graph.rs @@ -0,0 +1,153 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations +// under the License. + +use petgraph::data::{Build, Create}; +use petgraph::visit::{Data, NodeIndexable}; + +use super::utils::get_num_nodes; +use super::InvalidInputError; + +/// Generate a path graph +/// +/// Arguments: +/// +/// * `num_nodes` - The number of nodes to create a path graph for. Either this or +/// `weights must be specified. If both this and `weights are specified, weights +/// will take priorty and this argument will be ignored +/// * `weights` - A `Vec` of node weight objects. +/// * `default_node_weight` - A callable that will return the weight to use +/// for newly created nodes. This is ignored if `weights` is specified, +/// as the weights from that argument will be used instead. +/// * `default_edge_weight` - A callable that will return the weight object +/// to use for newly created edges. +/// * `bidirectional` - Whether edges are added bidirectionally, if set to +/// `true` then for any edge `(u, v)` an edge `(v, u)` will also be added. +/// If the graph is undirected this will result in a pallel edge. +/// +/// # Example +/// ```rust +/// use rustworkx_core::petgraph; +/// use rustworkx_core::generators::path_graph; +/// use rustworkx_core::petgraph::visit::EdgeRef; +/// +/// let g: petgraph::graph::UnGraph<(), ()> = path_graph( +/// Some(4), +/// None, +/// || {()}, +/// || {()}, +/// false +/// ).unwrap(); +/// assert_eq!( +/// vec![(0, 1), (1, 2), (2, 3)], +/// g.edge_references() +/// .map(|edge| (edge.source().index(), edge.target().index())) +/// .collect::>(), +/// ) +/// ``` +pub fn path_graph( + num_nodes: Option, + weights: Option>, + mut default_node_weight: F, + mut default_edge_weight: H, + bidirectional: bool, +) -> Result +where + G: Build + Create + Data + NodeIndexable, + F: FnMut() -> T, + H: FnMut() -> M, +{ + if weights.is_none() && num_nodes.is_none() { + return Err(InvalidInputError {}); + } + let node_len = get_num_nodes(&num_nodes, &weights); + let num_edges = if bidirectional { + 2 * node_len + } else { + node_len + }; + let mut graph = G::with_capacity(node_len, num_edges); + if node_len == 0 { + return Ok(graph); + } + + match weights { + Some(weights) => { + for weight in weights { + graph.add_node(weight); + } + } + None => { + for _ in 0..node_len { + graph.add_node(default_node_weight()); + } + } + }; + for a in 0..node_len - 1 { + let node_a = graph.from_index(a); + let node_b = graph.from_index(a + 1); + graph.add_edge(node_a, node_b, default_edge_weight()); + if bidirectional { + graph.add_edge(node_b, node_a, default_edge_weight()); + } + } + Ok(graph) +} + +#[cfg(test)] +mod tests { + use crate::generators::path_graph; + use crate::generators::InvalidInputError; + use crate::petgraph; + use crate::petgraph::visit::EdgeRef; + + #[test] + fn test_with_weights() { + let g: petgraph::graph::UnGraph = + path_graph(None, Some(vec![0, 1, 2, 3]), || 4, || (), false).unwrap(); + assert_eq!( + vec![(0, 1), (1, 2), (2, 3)], + g.edge_references() + .map(|edge| (edge.source().index(), edge.target().index())) + .collect::>(), + ); + assert_eq!( + vec![0, 1, 2, 3], + g.node_weights().copied().collect::>(), + ); + } + + #[test] + fn test_bidirectional() { + let g: petgraph::graph::DiGraph<(), ()> = + path_graph(Some(4), None, || (), || (), true).unwrap(); + assert_eq!( + vec![(0, 1), (1, 0), (1, 2), (2, 1), (2, 3), (3, 2),], + g.edge_references() + .map(|edge| (edge.source().index(), edge.target().index())) + .collect::>(), + ); + } + + #[test] + fn test_error() { + match path_graph::, (), _, _, ()>( + None, + None, + || (), + || (), + false, + ) { + Ok(_) => panic!("Returned a non-error"), + Err(e) => assert_eq!(e, InvalidInputError), + }; + } +} diff --git a/src/generators.rs b/src/generators.rs index c6fa47a18..acba83abf 100644 --- a/src/generators.rs +++ b/src/generators.rs @@ -192,47 +192,21 @@ pub fn directed_path_graph( bidirectional: bool, multigraph: bool, ) -> PyResult { - if weights.is_none() && num_nodes.is_none() { - return Err(PyIndexError::new_err( - "num_nodes and weights list not specified", - )); - } - let node_len = get_num_nodes(&num_nodes, &weights); - let num_edges = if bidirectional { - 2 * node_len - } else { - node_len - }; - let mut graph = StablePyGraph::::with_capacity(node_len, num_edges); - if node_len == 0 { - return Ok(digraph::PyDiGraph { - graph, - node_removed: false, - check_cycle: false, - cycle_state: algo::DfsSpace::default(), - multigraph, - attrs: py.None(), - }); - } - - match weights { - Some(weights) => { - for weight in weights { - graph.add_node(weight); - } + let default_fn = || py.None(); + let graph: StablePyGraph = match core_generators::path_graph( + num_nodes, + weights, + default_fn, + default_fn, + bidirectional, + ) { + Ok(graph) => graph, + Err(_) => { + return Err(PyIndexError::new_err( + "num_nodes and weights list not specified", + )) } - None => (0..node_len).for_each(|_| { - graph.add_node(py.None()); - }), }; - for a in 0..node_len - 1 { - let node_b = NodeIndex::new(a + 1); - let node_a = NodeIndex::new(a); - graph.add_edge(node_a, node_b, py.None()); - if bidirectional { - graph.add_edge(node_b, node_a, py.None()); - } - } Ok(digraph::PyDiGraph { graph, node_removed: false, @@ -276,35 +250,16 @@ pub fn path_graph( weights: Option>, multigraph: bool, ) -> PyResult { - if weights.is_none() && num_nodes.is_none() { - return Err(PyIndexError::new_err( - "num_nodes and weights list not specified", - )); - } - let node_len = get_num_nodes(&num_nodes, &weights); - let mut graph = StablePyGraph::::with_capacity(node_len, node_len); - if node_len == 0 { - return Ok(graph::PyGraph { - graph, - node_removed: false, - multigraph, - attrs: py.None(), - }); - } - match weights { - Some(weights) => { - for weight in weights { - graph.add_node(weight); + let default_fn = || py.None(); + let graph: StablePyGraph = + match core_generators::path_graph(num_nodes, weights, default_fn, default_fn, false) { + Ok(graph) => graph, + Err(_) => { + return Err(PyIndexError::new_err( + "num_nodes and weights list not specified", + )) } - } - None => (0..node_len).for_each(|_| { - graph.add_node(py.None()); - }), - }; - for node_a in 0..node_len - 1 { - let node_b = NodeIndex::new(node_a + 1); - graph.add_edge(NodeIndex::new(node_a), node_b, py.None()); - } + }; Ok(graph::PyGraph { graph, node_removed: false,