Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update mpl_draw() to fix multigraph plots #1204

Merged
merged 13 commits into from
Jun 10, 2024
18 changes: 17 additions & 1 deletion rustworkx/visualization/matplotlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -754,6 +754,16 @@ def _connectionstyle(posA, posB, *args, **kwargs):
else:
line_width = width

# radius of edges

reverse_edge = np.concatenate(([dst], [src]))
for edge in edge_pos: # the loop can be optimized
if bool(np.sum(np.all(np.equal(edge, reverse_edge)))):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think there are easier and simpler ways to write this. I don't quite understand why we are using NumPy for this check either

rad = 0.25
break
else:
rad = 0

arrow = mpl.patches.FancyArrowPatch(
(x1, y1),
(x2, y2),
Expand All @@ -763,7 +773,7 @@ def _connectionstyle(posA, posB, *args, **kwargs):
mutation_scale=mutation_scale,
color=arrow_color,
linewidth=line_width,
connectionstyle=_connectionstyle,
connectionstyle=connectionstyle + f", rad = {rad}",
maxwell04-wq marked this conversation as resolved.
Show resolved Hide resolved
linestyle=style,
zorder=1,
) # arrows go behind nodes
Expand Down Expand Up @@ -1001,6 +1011,12 @@ def draw_edge_labels(
x1 * label_pos + x2 * (1.0 - label_pos),
y1 * label_pos + y2 * (1.0 - label_pos),
)
if (n2, n1) in labels.keys(): # loop
x += 0.05 * label_pos
if n2 > n1:
y -= 0.25
else:
y += 0.25

if rotate:
# in degrees
Expand Down