Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add method: has_node #1169

Merged
merged 4 commits into from
Apr 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions releasenotes/notes/add-has_node-method-9e6b91bf79e60f50.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
---
features:
- |
Added a method :meth:`~rustworkx.PyGraph.has_node`
to the
:class:`~rustworkx.PyGraph` and :class:`~rustworkx.PyDiGraph`
classes to check if a node is in the graph.
2 changes: 2 additions & 0 deletions rustworkx/rustworkx.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -1146,6 +1146,7 @@ class PyGraph(Generic[_S, _T]):
def get_edge_data_by_index(self, edge_index: int, /) -> _T: ...
def get_edge_endpoints_by_index(self, edge_index: int, /) -> tuple[int, int]: ...
def get_node_data(self, node: int, /) -> _S: ...
def has_node(self, node: int, /) -> bool: ...
def has_edge(self, node_a: int, node_b: int, /) -> bool: ...
def has_parallel_edges(self) -> bool: ...
def in_edges(self, node: int, /) -> WeightedEdgeList[_T]: ...
Expand Down Expand Up @@ -1303,6 +1304,7 @@ class PyDiGraph(Generic[_S, _T]):
def get_node_data(self, node: int, /) -> _S: ...
def get_edge_data_by_index(self, edge_index: int, /) -> _T: ...
def get_edge_endpoints_by_index(self, edge_index: int, /) -> tuple[int, int]: ...
def has_node(self, node: int, /) -> bool: ...
def has_edge(self, node_a: int, node_b: int, /) -> bool: ...
def has_parallel_edges(self) -> bool: ...
def in_degree(self, node: int, /) -> int: ...
Expand Down
18 changes: 15 additions & 3 deletions src/digraph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -595,6 +595,18 @@ impl PyDiGraph {
self.node_indices()
}

/// Return True if there is a node in the graph.
///
/// :param int node: The node index to check
///
/// :returns: True if there is a node false if there is no node
/// :rtype: bool
#[pyo3(text_signature = "(self, node, /)")]
pub fn has_node(&self, node: usize) -> bool {
let index = NodeIndex::new(node);
self.graph.contains_node(index)
}

/// Return True if there is an edge from node_a to node_b.
///
/// :param int node_a: The source node index to check for an edge
Expand Down Expand Up @@ -3059,7 +3071,7 @@ impl PyDiGraph {
/// required to return a boolean value stating whether the node's data payload fits some criteria.
///
/// For example::
///
///
/// from rustworkx import PyDiGraph
///
/// graph = PyDiGraph()
Expand Down Expand Up @@ -3107,8 +3119,8 @@ impl PyDiGraph {
/// def my_filter_function(edge):
/// if edge:
/// return edge == 'B'
/// return False
///
/// return False
///
/// indices = graph.filter_edges(my_filter_function)
/// assert indices == [1]
///
Expand Down
18 changes: 15 additions & 3 deletions src/graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -468,6 +468,18 @@ impl PyGraph {
self.node_indices()
}

/// Return True if there is a node.
///
/// :param int node: The index for the node
///
/// :returns: True if there is a node false if there is no node
/// :rtype: bool
#[pyo3(text_signature = "(self, node, /)")]
pub fn has_node(&self, node: usize) -> bool {
let index = NodeIndex::new(node);
self.graph.contains_node(index)
}

/// Return True if there is an edge between ``node_a`` and ``node_b``.
///
/// :param int node_a: The index for the first node
Expand Down Expand Up @@ -2039,7 +2051,7 @@ impl PyGraph {
/// required to return a boolean value stating whether the node's data payload fits some criteria.
///
/// For example::
///
///
/// from rustworkx import PyGraph
///
/// graph = PyGraph()
Expand Down Expand Up @@ -2087,8 +2099,8 @@ impl PyGraph {
/// def my_filter_function(edge):
/// if edge:
/// return edge == 'B'
/// return False
///
/// return False
///
/// indices = graph.filter_edges(my_filter_function)
/// assert indices == [1]
///
Expand Down
6 changes: 6 additions & 0 deletions tests/digraph/test_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,12 @@ def test_remove_nodes_from_with_invalid_index(self):
self.assertEqual(["a"], res)
self.assertEqual([0], dag.node_indexes())

def test_has_node(self):
dag = rustworkx.PyDAG()
node_a = dag.add_node("a")
self.assertTrue(dag.has_node(node_a))
self.assertFalse(dag.has_node(node_a + 1))
IvanIsCoding marked this conversation as resolved.
Show resolved Hide resolved

def test_remove_nodes_retain_edges_single_edge(self):
dag = rustworkx.PyDAG()
node_a = dag.add_node("a")
Expand Down
6 changes: 6 additions & 0 deletions tests/graph/test_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,3 +179,9 @@ def test_remove_node_delitem_invalid_index(self):
res = graph.nodes()
self.assertEqual(["a", "b", "c"], res)
self.assertEqual([0, 1, 2], graph.node_indexes())

def test_has_node(self):
graph = rustworkx.PyGraph()
node_a = graph.add_node("a")
self.assertTrue(graph.has_node(node_a))
self.assertFalse(graph.has_node(node_a + 1))
Loading