Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add find_predecessor_node_by_edge. #756

Merged
merged 15 commits into from
Dec 13, 2022
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
---
features:
- |
Add a method
:meth:`~rustworkx.DiGraph.find_predecessor_node_by_edge` to get
the immediate predecessor of a node which is connected by the
specified edge.
36 changes: 36 additions & 0 deletions src/digraph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1724,6 +1724,42 @@ impl PyDiGraph {
Err(NoSuitableNeighbors::new_err("No suitable neighbor"))
}

/// Find a source node with a specific edge
///
/// This method is used to find a predecessor of
/// a given node given an edge condition.
///
/// :param int node: The node to use as the source of the search
/// :param callable predicate: A python callable that will take a single
/// parameter, the edge object, and will return a boolean if the
/// edge matches or not
///
/// :returns: The node object that has an edge from it to the provided
/// node index which matches the provided condition
#[pyo3(text_signature = "(self, node, predicate, /)")]
pub fn find_predecessor_node_by_edge(
&self,
py: Python,
node: usize,
predicate: PyObject,
) -> PyResult<&PyObject> {
let predicate_callable = |a: &PyObject| -> PyResult<PyObject> {
let res = predicate.call1(py, (a,))?;
Ok(res.to_object(py))
};
let index = NodeIndex::new(node);
let dir = petgraph::Direction::Incoming;
let edges = self.graph.edges_directed(index, dir);
for edge in edges {
let edge_predicate_raw = predicate_callable(edge.weight())?;
let edge_predicate: bool = edge_predicate_raw.extract(py)?;
if edge_predicate {
return Ok(self.graph.node_weight(edge.source()).unwrap());
}
}
Err(NoSuitableNeighbors::new_err("No suitable neighbor"))
}

/// Generate a dot file from the graph
///
/// :param node_attr: A callable that will take in a node data object
Expand Down
26 changes: 26 additions & 0 deletions tests/rustworkx_tests/digraph/test_edges.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,32 @@ def compare_edges(edge):
with self.assertRaises(rustworkx.NoSuitableNeighbors):
dag.find_adjacent_node_by_edge(node_a, compare_edges)

def test_find_predecessor_node_by_edge(self):
dag = rustworkx.PyDAG()
node_a = dag.add_node("a")
node_b = dag.add_child(node_a, "b", "a to b")
node_c = dag.add_child(node_b, "c", "b to c")
dag.add_child(node_c, "d", "c to d")

def compare_edges(edge):
return "a to b" == edge

res = dag.find_predecessor_node_by_edge(node_b, compare_edges)
self.assertEqual("a", res)

def test_find_predecessor_node_by_edge_no_match(self):
dag = rustworkx.PyDAG()
node_a = dag.add_node("a")
node_b = dag.add_child(node_a, "b", "a to b")
node_c = dag.add_child(node_b, "c", "b to c")
dag.add_child(node_c, "d", "c to d")

def compare_edges(edge):
return "b to c" == edge

with self.assertRaises(rustworkx.NoSuitableNeighbors):
dag.find_predecessor_node_by_edge(node_b, compare_edges)

def test_add_edge_from(self):
dag = rustworkx.PyDAG()
nodes = list(range(4))
Expand Down