Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement substitute_node_with_subgraph to Pygraph #894

Merged
merged 14 commits into from
Jun 20, 2023
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
---
features:
- |
Added method substitute_node_with_subgraph to the PyGraph class.

.. jupyter-execute::

import rustworkx
from rustworkx.visualization import * # Needs matplotlib/

graph = rustworkx.generators.complete_graph(5)
sub_graph = rustworkx.generators.path_graph(3)

# Replace node 4 in this graph with sub_graph
# Make sure to connect the graphs at node 2 of the sub_graph
# This is done by passing a function that returns 2

graph.substitute_node_with_subgraph(4, sub_graph, lambda _, __, ___: 2)

# Draw the updated graph
mpl_draw(graph, with_labels=True)
fixes:
- |
Fixes missing method that is present in PyDiGraph but not in PyGraph.
see `#837 <https://github.com/Qiskit/rustworkx/issues/837>`__ for more info.
9 changes: 9 additions & 0 deletions rustworkx/graph.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,15 @@ class PyGraph(Generic[S, T]):
def remove_node(self, node: int, /) -> None: ...
def remove_nodes_from(self, index_list: Sequence[int], /) -> None: ...
def subgraph(self, nodes: Sequence[int], /, preserve_attrs: bool = ...) -> PyGraph[S, T]: ...
def substitute_node_with_subgraph(
self,
node: int,
other: PyGraph[S, T],
edge_map_fn: Callable[[int, int, T], Optional[int]],
/,
node_filter: Optional[Callable[[S], bool]] = ...,
edge_weight_map: Optional[Callable[[T], T]] = ...,
) -> NodeMap: ...
def to_dot(
self,
/,
Expand Down
129 changes: 129 additions & 0 deletions src/graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ use num_traits::Zero;
use numpy::Complex64;
use numpy::PyReadonlyArray2;

use crate::iterators::NodeMap;

use super::dot_utils::build_dot;
use super::iterators::{EdgeIndexMap, EdgeIndices, EdgeList, NodeIndices, WeightedEdgeList};
use super::{
Expand Down Expand Up @@ -1594,6 +1596,133 @@ impl PyGraph {
Ok(out_dict.into())
}

/// Docs Pending
///
raynelfss marked this conversation as resolved.
Show resolved Hide resolved
#[pyo3(
text_signature = "(self, node, other, edge_map_fn, /, node_filter=None, edge_weight_map=None"
)]
fn substitute_node_with_subgraph(
&mut self,
py: Python,
node: usize,
other: &PyGraph,
edge_map_fn: PyObject,
node_filter: Option<PyObject>,
edge_weight_map: Option<PyObject>,
) -> PyResult<NodeMap> {
let filter_fn = |obj: &PyObject, filter_fn: &Option<PyObject>| -> PyResult<bool> {
match filter_fn {
Some(filter) => {
let res = filter.call1(py, (obj,)).unwrap();
raynelfss marked this conversation as resolved.
Show resolved Hide resolved
res.extract(py)
}
None => Ok(true),
}
};

let weight_map_fn = |obj: &PyObject, weight_fn: &Option<PyObject>| -> PyResult<PyObject> {
match weight_fn {
Some(weight_fn) => weight_fn.call1(py, (obj,)),
None => Ok(obj.clone_ref(py)),
}
};

let map_fn = |source: usize, target: usize, weight: &PyObject| -> PyResult<Option<usize>> {
let res = edge_map_fn.call1(py, (source, target, weight)).unwrap();
raynelfss marked this conversation as resolved.
Show resolved Hide resolved
res.extract(py)
};

let node_index = NodeIndex::new(node);
if self.graph.node_weight(node_index).is_none() {
return Err(PyIndexError::new_err(format!(
"Specified node {} is not in this graph",
node
)));
}

// Copy all nodes from other to self
let mut out_map: DictMap<usize, usize> = DictMap::with_capacity(other.node_count());
for node in other.graph.node_indices() {
let node_weight: Py<PyAny> = other.graph[node].clone_ref(py);
if !filter_fn(&node_weight, &node_filter)? {
continue;
}
let new_index: NodeIndex = self.graph.add_node(node_weight);
out_map.insert(node.index(), new_index.index());
}

if out_map.is_empty() {
self.graph.remove_node(node_index);
return Ok(NodeMap {
node_map: DictMap::new(),
});
}

// Copy all edges
for edge in other.graph.edge_references().filter(|edge| {
out_map.contains_key(&edge.target().index())
&& out_map.contains_key(&edge.source().index())
}) {
self._add_edge(
NodeIndex::new(out_map[&edge.source().index()]),
NodeIndex::new(out_map[&edge.target().index()]),
weight_map_fn(edge.weight(), &edge_weight_map).unwrap(),
);
}
// Incoming and outgoing edges.
let in_edges: Vec<(NodeIndex, NodeIndex, PyObject)> = self
.graph
.edge_references()
.filter(|edge| edge.target() == node_index)
.map(|edge| (edge.source(), edge.target(), edge.weight().clone_ref(py)))
.collect();
// Keep track of what's present on incoming edges
let in_set: HashSet<(NodeIndex, NodeIndex)> =
in_edges.iter().map(|edge| (edge.0, edge.1)).collect();
// Retrieve outgoing edges. Make sure to not include any incoming edge.
let out_edges: Vec<(NodeIndex, NodeIndex, PyObject)> = self
.graph
.edges(node_index)
.filter(|edge| !in_set.contains(&(edge.target(), edge.source())))
.map(|edge| (edge.source(), edge.target(), edge.weight().clone_ref(py)))
.collect();
for (source, target, weight) in in_edges {
let old_index: Option<usize> = map_fn(source.index(), target.index(), &weight).unwrap();
raynelfss marked this conversation as resolved.
Show resolved Hide resolved
let target_out: NodeIndex = match old_index {
Some(old_index) => match out_map.get(&old_index) {
Some(new_index) => NodeIndex::new(*new_index),
None => {
return Err(PyIndexError::new_err(format!(
"No matter index {} found",
raynelfss marked this conversation as resolved.
Show resolved Hide resolved
old_index
)))
}
},
None => continue,
};
self._add_edge(source, target_out, weight);
}
for (source, target, weight) in out_edges {
let old_index: Option<usize> = map_fn(source.index(), target.index(), &weight).unwrap();
raynelfss marked this conversation as resolved.
Show resolved Hide resolved
let source_out: NodeIndex = match old_index {
Some(old_index) => match out_map.get(&old_index) {
Some(new_index) => NodeIndex::new(*new_index),
None => {
return Err(PyIndexError::new_err(format!(
"No matter index {} found",
raynelfss marked this conversation as resolved.
Show resolved Hide resolved
old_index
)))
}
},
None => continue,
};
self._add_edge(source_out, target, weight);
}
// Remove original node
self.graph.remove_node(node_index);
Ok(NodeMap { node_map: out_map })
}

/// Substitute a set of nodes with a single new node.
///
/// .. note::
Expand Down
141 changes: 141 additions & 0 deletions tests/rustworkx_tests/graph/test_substitute_node_with_subgraph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.

import unittest
import rustworkx


class TestSubstituteNodeSubGraph(unittest.TestCase):
def setUp(self) -> None:
super().setUp()
self.graph = rustworkx.generators.path_graph(5)

def test_empty_replacement(self):
in_graph = rustworkx.PyGraph()
res = self.graph.substitute_node_with_subgraph(3, in_graph, lambda _, __, ___: None)
self.assertEqual(res, {})
self.assertEqual([(0, 1), (1, 2)], self.graph.edge_list())

def test_single_node(self):
in_graph = rustworkx.generators.path_graph(1)
res = self.graph.substitute_node_with_subgraph(2, in_graph, lambda _, __, ___: 0)
self.assertEqual(res, {0: 5})
self.assertEqual([(0, 1), (1, 5), (3, 4), (5, 3)], sorted(self.graph.edge_list()))

def test_node_filter(self):
in_graph = rustworkx.generators.complete_graph(5)
res = self.graph.substitute_node_with_subgraph(
0, in_graph, lambda _, __, ___: 2, node_filter=lambda node: node == None
)
self.assertEqual(res, {i: i + 5 for i in range(5)})
self.assertEqual(
[
(1, 2),
(2, 3),
(3, 4),
(5, 6),
(5, 7),
(5, 8),
(5, 9),
(6, 7),
(6, 8),
(6, 9),
(7, 1),
(7, 8),
(7, 9),
(8, 9),
],
sorted(self.graph.edge_list()),
)

def test_edge_weight_modifier(self):
in_graph = rustworkx.PyGraph()
in_graph.add_node("meep")
in_graph.add_node("moop")
in_graph.add_edges_from(
[
(
0,
1,
"edge",
)
]
)
res = self.graph.substitute_node_with_subgraph(
2,
in_graph,
lambda _, __, ___: 0,
edge_weight_map=lambda edge: edge + "-migrated",
)
self.assertEqual([(0, 1), (3, 4), (5, 6), (1, 5), (5, 3)], self.graph.edge_list())
self.assertEqual("edge-migrated", self.graph.get_edge_data(5, 6))
self.assertEqual(res, {0: 5, 1: 6})

def test_none_mapping(self):
in_graph = rustworkx.PyGraph()
in_graph.add_node("boop")
in_graph.add_node("beep")
in_graph.add_edges_from([(0, 1, "edge")])
res = self.graph.substitute_node_with_subgraph(2, in_graph, lambda _, __, ___: None)
self.assertEqual([(0, 1), (3, 4), (5, 6)], self.graph.edge_list())
self.assertEqual(res, {0: 5, 1: 6})

def test_multiple_mapping(self):
graph = rustworkx.generators.star_graph(5)
in_graph = rustworkx.generators.star_graph(3)

def map_function(_source, target, _weight):
if target > 2:
return 2
return 1

res = graph.substitute_node_with_subgraph(0, in_graph, map_function)
self.assertEqual({0: 5, 1: 6, 2: 7}, res)
expected = [(5, 6), (5, 7), (7, 4), (7, 3), (6, 2), (6, 1)]
self.assertEqual(sorted(expected), sorted(graph.edge_list()))

def test_multiple_mapping_full(self):
graph = rustworkx.generators.star_graph(5)
in_graph = rustworkx.generators.star_graph(weights=list(range(3)))
in_graph.add_edge(1, 2, None)

def map_function(source, target, _weight):
if target > 2:
return 2
return 1

def filter_fn(node):
return node > 0

def map_weight(_):
return "migrated"

res = graph.substitute_node_with_subgraph(0, in_graph, map_function, filter_fn, map_weight)
self.assertEqual({1: 5, 2: 6}, res)
expected = [
(5, 6, "migrated"),
(6, 4, None),
(6, 3, None),
(5, 2, None),
(5, 1, None),
]
self.assertEqual(expected, graph.weighted_edge_list())

def test_invalid_target(self):
in_graph = rustworkx.generators.grid_graph(5, 5)
with self.assertRaises(IndexError):
self.graph.substitute_node_with_subgraph(0, in_graph, lambda *args: 42)

def test_invalid_node_id(self):
in_graph = rustworkx.generators.grid_graph(5, 5)
with self.assertRaises(IndexError):
self.graph.substitute_node_with_subgraph(16, in_graph, lambda *args: None)