Skip to content

Commit

Permalink
Update Colors and Readme (#23)
Browse files Browse the repository at this point in the history
* Change screenshots with new colors

* change default colors

* add image preview

* Update readme

* modified by prettier
  • Loading branch information
willyfh authored Feb 25, 2024
1 parent 47acc2a commit 0b99f11
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 14 deletions.
31 changes: 19 additions & 12 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@
# ⭐ VisualTorch ⭐
<div align="center">
<h1>🔥 VisualTorch 🔥</h1>

[![python](https://img.shields.io/badge/python-3.10%2B-blue)]() [![pytorch](https://img.shields.io/badge/pytorch-2.0%2B-orange)]() [![Downloads](https://static.pepy.tech/personalized-badge/visualtorch?period=total&units=international_system&left_color=grey&right_color=green&left_text=PyPI%20Downloads)](https://pepy.tech/project/visualtorch) [![Run Tests](https://github.com/willyfh/visualtorch/actions/workflows/pytest.yml/badge.svg)](https://github.com/willyfh/visualtorch/actions/workflows/pytest.yml)

**VisualTorch** aims to help visualize Torch-based neural network architectures. Currently, this package supports generating layered-style and graph-style architectures for PyTorch Sequential and Custom models. This package is adapted from [visualkeras](https://github.com/paulgavrikov/visualkeras), [pytorchviz](https://github.com/szagoruyko/pytorchviz), and [pytorch-summary](https://github.com/sksq96/pytorch-summary).
</div>

**VisualTorch** aims to help visualize Torch-based neural network architectures. It currently supports generating layered-style and graph-style architectures for PyTorch Sequential and Custom models. This tool is adapted from [visualkeras](https://github.com/paulgavrikov/visualkeras), [pytorchviz](https://github.com/szagoruyko/pytorchviz), and [pytorch-summary](https://github.com/sksq96/pytorch-summary).

**Note:** VisualTorch may not yet support complex models, but contributions are welcome!

![layered-and-graph](https://github.com/willyfh/visualtorch/assets/5786636/694e6e6c-58ea-46d6-9280-348337c08ec7)

**v0.2**: Added support for custom models and implemented graph view functionality.

Expand Down Expand Up @@ -52,7 +59,7 @@ input_shape = (1, 3, 224, 224)
visualtorch.layered_view(model, input_shape=input_shape, legend=True).show() # display using your system viewer
```

![simple-cnn-sequential](https://github.com/willyfh/visualtorch/assets/5786636/9b646fac-c336-4253-ac01-8f3e6b2fcc0b)
![simple-cnn](https://github.com/willyfh/visualtorch/assets/5786636/e8da2a52-66c6-4fda-85f8-7243702fd1f2)

### Custom Model

Expand Down Expand Up @@ -98,7 +105,7 @@ input_shape = (1, 3, 224, 224)
visualtorch.layered_view(model, input_shape=input_shape, legend=True).show() # display using your system viewer
```

![simple-cnn-custom](https://github.com/willyfh/visualtorch/assets/5786636/f22298b4-f341-4a0d-b85b-11f01e207ad8)
![simple-cnn-custom](https://github.com/willyfh/visualtorch/assets/5786636/9f18db76-838d-4cd1-87ac-3ac5d3509423)

### Graph View

Expand Down Expand Up @@ -129,7 +136,7 @@ input_shape = (1, 4)
visualtorch.graph_view(model, input_shape).show()
```

![graph-view](https://github.com/willyfh/visualtorch/assets/5786636/a65b4208-72da-497b-b6c9-aafc82b67b58)
![graph](https://github.com/willyfh/visualtorch/assets/5786636/9868f8be-7bfb-4892-ad3b-72de56955c75)

### Save the Image

Expand All @@ -143,7 +150,7 @@ visualtorch.layered_view(model, input_shape=input_shape, legend=True, to_file='o
visualtorch.layered_view(model, input_shape=input_shape, draw_volume=False)
```

![2d-view](https://github.com/willyfh/visualtorch/assets/5786636/5b16c252-f589-4b3f-8ea4-1bc188e6c124)
![2d-view](https://github.com/willyfh/visualtorch/assets/5786636/71848bfa-5447-4e66-bf4c-84f9e51a581e)

### Custom Color

Expand All @@ -153,17 +160,17 @@ Use 'fill' to change the color of the layer, and use 'outline' to change the col
from collections import defaultdict

color_map = defaultdict(dict)
color_map[nn.Conv2d]['fill'] = '#FF6F61' # Coral red
color_map[nn.ReLU]['fill'] = 'skyblue'
color_map[nn.MaxPool2d]['fill'] = '#88B04B' # Sage green
color_map[nn.Flatten]['fill'] = 'gold'
color_map[nn.Linear]['fill'] = '#6B5B95' # Royal purple
color_map[nn.Conv2d]['fill'] = 'LightSlateGray' # Light Slate Gray
color_map[nn.ReLU]['fill'] = '#87CEFA' # Light Sky Blue
color_map[nn.MaxPool2d]['fill'] = 'LightSeaGreen' # Light Sea Green
color_map[nn.Flatten]['fill'] = '#98FB98' # Pale Green
color_map[nn.Linear]['fill'] = 'LightSteelBlue' # Light Steel Blue

input_shape = (1, 3, 224, 224)
visualtorch.layered_view(model, input_shape=input_shape, color_map=color_map
```

![custom-color](https://github.com/willyfh/visualtorch/assets/5786636/57f28191-d86e-4419-a015-f5fc7fa17b5a)
![custom-color](https://github.com/willyfh/visualtorch/assets/5786636/2e536ffd-8441-4e66-90ff-d152da67363e)

## Contributing

Expand Down
2 changes: 1 addition & 1 deletion visualtorch/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def graph_view(

current_y = c.y2 + node_spacing

c.fill = color_map.get(type(layer), {}).get("fill", "blue")
c.fill = color_map.get(type(layer), {}).get("fill", "#ADD8E6")
c.outline = color_map.get(type(layer), {}).get("outline", "black")

layer_nodes.append(c)
Expand Down
2 changes: 1 addition & 1 deletion visualtorch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def __init__(self, colors: list | None = None):
self.colors = (
colors
if colors is not None
else ["#ffd166", "#ef476f", "#06d6a0", "#118ab2", "#073b4c"]
else ["#FFE4B5", "#ADD8E6", "#98FB98", "#FFA07A", "#D8BFD8"]
)

def get_color(self, class_type: type):
Expand Down

0 comments on commit 0b99f11

Please sign in to comment.