Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement Graph View Visualization #22

Merged
merged 5 commits into from
Feb 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 35 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@

[![python](https://img.shields.io/badge/python-3.10%2B-blue)]() [![pytorch](https://img.shields.io/badge/pytorch-2.0%2B-orange)]() [![Downloads](https://static.pepy.tech/personalized-badge/visualtorch?period=total&units=international_system&left_color=grey&right_color=green&left_text=PyPI%20Downloads)](https://pepy.tech/project/visualtorch) [![Run Tests](https://github.com/willyfh/visualtorch/actions/workflows/pytest.yml/badge.svg)](https://github.com/willyfh/visualtorch/actions/workflows/pytest.yml)

**VisualTorch** aims to help visualize Torch-based neural network architectures. Currently, this package supports generating layered-style architectures for Torch Sequential and Custom models. This package is adapted from [visualkeras](https://github.com/paulgavrikov/visualkeras) by [@paulgavrikov](https://github.com/paulgavrikov).
**VisualTorch** aims to help visualize Torch-based neural network architectures. Currently, this package supports generating layered-style and graph-style architectures for PyTorch Sequential and Custom models. This package is adapted from [visualkeras](https://github.com/paulgavrikov/visualkeras), [pytorchviz](https://github.com/szagoruyko/pytorchviz), and [pytorch-summary](https://github.com/sksq96/pytorch-summary).

**v0.2**: Support for custom models has been added.
**v0.2**: Added support for custom models and implemented graph view functionality.

**v0.1.1**: Support for the layered architecture of Torch Sequential.
**v0.1.1**: Added support for the layered architecture of Torch Sequential.

## Installation

Expand Down Expand Up @@ -100,6 +100,37 @@ visualtorch.layered_view(model, input_shape=input_shape, legend=True).show() # d

![simple-cnn-custom](https://github.com/willyfh/visualtorch/assets/5786636/f22298b4-f341-4a0d-b85b-11f01e207ad8)

### Graph View

```python
import torch
import torch.nn as nn
import visualtorch

class SimpleDense(nn.Module):
def __init__(self):
super(SimpleDense, self).__init__()
self.h0 = nn.Linear(4, 8)
self.h1 = nn.Linear(8, 8)
self.h2 = nn.Linear(8, 4)
self.out = nn.Linear(4, 2)

def forward(self, x):
x = self.h0(x)
x = self.h1(x)
x = self.h2(x)
x = self.out(x)
return x

model = SimpleDense()

input_shape = (1, 4)

visualtorch.graph_view(model, input_shape).show()
```

![graph-view](https://github.com/willyfh/visualtorch/assets/5786636/a65b4208-72da-497b-b6c9-aafc82b67b58)

### Save the Image

```python
Expand Down Expand Up @@ -142,7 +173,7 @@ Please feel free to send a pull request to contribute to this project.

This poject is available as open source under the terms of the [MIT License](https://github.com/willyfh/visualtorch/blob/update-readme/LICENSE).

Originally, this project was based on the [visualkeras](https://github.com/paulgavrikov/visualkeras) (under the MIT license).
Originally, this project was based on the [visualkeras](https://github.com/paulgavrikov/visualkeras) (under the MIT license), with additional modifications inspired by [pytorchviz](https://github.com/szagoruyko/pytorchviz), and [pytorch-summary](https://github.com/sksq96/pytorch-summary), both of which are also licensed under the MIT license.

## Citation

Expand Down
3 changes: 2 additions & 1 deletion visualtorch/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from visualtorch.layered import layered_view
from visualtorch.graph import graph_view

__all__ = ["layered_view"]
__all__ = ["layered_view", "graph_view"]
196 changes: 196 additions & 0 deletions visualtorch/graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
import aggdraw
from PIL import Image
from math import ceil
from .layer_utils import model_to_adj_matrix, add_input_dummy_layer
from .utils import Circle, Ellipses, get_keys_by_value, Box
import numpy as np
from typing import Optional, Dict, Any, Tuple, List
import torch


def graph_view(
model: torch.nn.Module,
input_shape: Tuple[int, ...],
to_file: Optional[str] = None,
color_map: Optional[Dict[Any, Any]] = None,
node_size: int = 50,
background_fill: Any = "white",
padding: int = 10,
layer_spacing: int = 250,
node_spacing: int = 10,
connector_fill: Any = "gray",
connector_width: int = 1,
ellipsize_after: int = 10,
inout_as_tensor: bool = True,
show_neurons: bool = True,
) -> Image.Image:
"""
Generates an architecture visualization for a given linear PyTorch model (i.e., one input and output tensor for each
layer) in a graph style.

Args:
model (torch.nn.Module): A PyTorch model that will be visualized.
input_shape (tuple): The shape of the input tensor.
to_file (str, optional): Path to the file to write the created image to. If the image does not exist yet,
it will be created, else overwritten. Image type is inferred from the file ending. Providing None
will disable writing.
color_map (dict, optional): Dict defining fill and outline for each layer by class type. Will fallback to default
values for not specified classes.
node_size (int, optional): Size in pixels each node will have.
background_fill (Any, optional): Color for the image background. Can be str or (R,G,B,A).
padding (int, optional): Distance in pixels before the first and after the last layer.
layer_spacing (int, optional): Spacing in pixels between two layers.
node_spacing (int, optional): Spacing in pixels between nodes.
connector_fill (Any, optional): Color for the connectors. Can be str or (R,G,B,A).
connector_width (int, optional): Line-width of the connectors in pixels.
ellipsize_after (int, optional): Maximum number of neurons per layer to draw. If a layer is exceeding this,
the remaining neurons will be drawn as ellipses.
inout_as_tensor (bool, optional): If True there will be one input and output node for each tensor, else the
tensor will be flattened and one node for each scalar will be created (e.g., a (10, 10) shape will be
represented by 100 nodes).
show_neurons (bool, optional): If True a node for each neuron in supported layers is created (constrained by
ellipsize_after), else each layer is represented by a node.

Returns:
Image.Image: Generated architecture image.
"""

if color_map is None:
color_map = dict()

# Iterate over the model to compute bounds and generate boxes

layers: List[Any] = list()
layer_y = list()

# Attach helper layers

id_to_num_mapping, adj_matrix, model_layers = model_to_adj_matrix(
model, input_shape
)

# Add fake input layers

id_to_num_mapping, adj_matrix, model_layers = add_input_dummy_layer(
input_shape, id_to_num_mapping, adj_matrix, model_layers
)

# Create architecture

current_x = padding # + input_label_size[0] + text_padding

id_to_node_list_map = dict()

for index, layer_list in enumerate(model_layers):
current_y = 0
nodes = []
for layer in layer_list:
is_box = True
units = 1

if show_neurons:
if hasattr(layer, "_saved_bias_sym_sizes_opt"):
is_box = False
units = layer._saved_bias_sym_sizes_opt[0]
elif hasattr(layer, "_saved_mat2_sym_sizes"):
is_box = False
units = layer._saved_mat2_sym_sizes[1]
elif hasattr(layer, "units"): # for dummy input layer
is_box = False
units = layer.units

n = min(units, ellipsize_after)
layer_nodes = list()

for i in range(n):
scale = 1
c: Box | Circle | Ellipses
if not is_box:
if i != ellipsize_after - 2:
c = Circle()
else:
c = Ellipses()
else:
c = Box()
scale = 3

c.x1 = current_x
c.y1 = current_y
c.x2 = c.x1 + node_size
c.y2 = c.y1 + node_size * scale

current_y = c.y2 + node_spacing

c.fill = color_map.get(type(layer), {}).get("fill", "blue")
c.outline = color_map.get(type(layer), {}).get("outline", "black")

layer_nodes.append(c)

id_to_node_list_map[str(id(layer))] = layer_nodes
nodes.extend(layer_nodes)
current_y += 2 * node_size

layer_y.append(current_y - node_spacing - 2 * node_size)
layers.append(nodes)
current_x += node_size + layer_spacing

# Generate image

img_width = (
len(layers) * node_size + (len(layers) - 1) * layer_spacing + 2 * padding
)
img_height = max(*layer_y) + 2 * padding
img = Image.new(
"RGBA", (int(ceil(img_width)), int(ceil(img_height))), background_fill
)

draw = aggdraw.Draw(img)

# y correction (centering)
for i, layer in enumerate(layers):
y_off = (img.height - layer_y[i]) / 2
node: Any
for node in layer:
node.y1 += y_off
node.y2 += y_off

for start_idx, end_idx in zip(*np.where(adj_matrix > 0)):
start_id = next(get_keys_by_value(id_to_num_mapping, start_idx))
end_id = next(get_keys_by_value(id_to_num_mapping, end_idx))

start_layer_list = id_to_node_list_map[start_id]
end_layer_list = id_to_node_list_map[end_id]

# draw connectors
for start_node_idx, start_node in enumerate(start_layer_list):
for end_node in end_layer_list:
if not isinstance(start_node, Ellipses) and not isinstance(
end_node, Ellipses
):
_draw_connector(
draw,
start_node,
end_node,
color=connector_fill,
width=connector_width,
)

for i, layer in enumerate(layers):
for node_index, node in enumerate(layer):
node.draw(draw)

draw.flush()

if to_file is not None:
img.save(to_file)

return img


def _draw_connector(draw, start_node, end_node, color, width):
pen = aggdraw.Pen(color, width)
x1 = start_node.x2
y1 = start_node.y1 + (start_node.y2 - start_node.y1) / 2
x2 = end_node.x1
y2 = end_node.y1 + (end_node.y2 - end_node.y1) / 2
draw.line([x1, y1, x2, y2], pen)
Loading
Loading