diff --git a/README.md b/README.md index 2c3815e..12f7c00 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,15 @@ -# ⭐ VisualTorch ⭐ +
+

🔥 VisualTorch 🔥

[![python](https://img.shields.io/badge/python-3.10%2B-blue)]() [![pytorch](https://img.shields.io/badge/pytorch-2.0%2B-orange)]() [![Downloads](https://static.pepy.tech/personalized-badge/visualtorch?period=total&units=international_system&left_color=grey&right_color=green&left_text=PyPI%20Downloads)](https://pepy.tech/project/visualtorch) [![Run Tests](https://github.com/willyfh/visualtorch/actions/workflows/pytest.yml/badge.svg)](https://github.com/willyfh/visualtorch/actions/workflows/pytest.yml) -**VisualTorch** aims to help visualize Torch-based neural network architectures. Currently, this package supports generating layered-style and graph-style architectures for PyTorch Sequential and Custom models. This package is adapted from [visualkeras](https://github.com/paulgavrikov/visualkeras), [pytorchviz](https://github.com/szagoruyko/pytorchviz), and [pytorch-summary](https://github.com/sksq96/pytorch-summary). +
+ +**VisualTorch** aims to help visualize Torch-based neural network architectures. It currently supports generating layered-style and graph-style architectures for PyTorch Sequential and Custom models. This tool is adapted from [visualkeras](https://github.com/paulgavrikov/visualkeras), [pytorchviz](https://github.com/szagoruyko/pytorchviz), and [pytorch-summary](https://github.com/sksq96/pytorch-summary). + +**Note:** VisualTorch may not yet support complex models, but contributions are welcome! + +![layered-and-graph](https://github.com/willyfh/visualtorch/assets/5786636/694e6e6c-58ea-46d6-9280-348337c08ec7) **v0.2**: Added support for custom models and implemented graph view functionality. @@ -52,7 +59,7 @@ input_shape = (1, 3, 224, 224) visualtorch.layered_view(model, input_shape=input_shape, legend=True).show() # display using your system viewer ``` -![simple-cnn-sequential](https://github.com/willyfh/visualtorch/assets/5786636/9b646fac-c336-4253-ac01-8f3e6b2fcc0b) +![simple-cnn](https://github.com/willyfh/visualtorch/assets/5786636/e8da2a52-66c6-4fda-85f8-7243702fd1f2) ### Custom Model @@ -98,7 +105,7 @@ input_shape = (1, 3, 224, 224) visualtorch.layered_view(model, input_shape=input_shape, legend=True).show() # display using your system viewer ``` -![simple-cnn-custom](https://github.com/willyfh/visualtorch/assets/5786636/f22298b4-f341-4a0d-b85b-11f01e207ad8) +![simple-cnn-custom](https://github.com/willyfh/visualtorch/assets/5786636/9f18db76-838d-4cd1-87ac-3ac5d3509423) ### Graph View @@ -129,7 +136,7 @@ input_shape = (1, 4) visualtorch.graph_view(model, input_shape).show() ``` -![graph-view](https://github.com/willyfh/visualtorch/assets/5786636/a65b4208-72da-497b-b6c9-aafc82b67b58) +![graph](https://github.com/willyfh/visualtorch/assets/5786636/9868f8be-7bfb-4892-ad3b-72de56955c75) ### Save the Image @@ -143,7 +150,7 @@ visualtorch.layered_view(model, input_shape=input_shape, legend=True, to_file='o visualtorch.layered_view(model, input_shape=input_shape, draw_volume=False) ``` -![2d-view](https://github.com/willyfh/visualtorch/assets/5786636/5b16c252-f589-4b3f-8ea4-1bc188e6c124) +![2d-view](https://github.com/willyfh/visualtorch/assets/5786636/71848bfa-5447-4e66-bf4c-84f9e51a581e) ### Custom Color @@ -153,17 +160,17 @@ Use 'fill' to change the color of the layer, and use 'outline' to change the col from collections import defaultdict color_map = defaultdict(dict) -color_map[nn.Conv2d]['fill'] = '#FF6F61' # Coral red -color_map[nn.ReLU]['fill'] = 'skyblue' -color_map[nn.MaxPool2d]['fill'] = '#88B04B' # Sage green -color_map[nn.Flatten]['fill'] = 'gold' -color_map[nn.Linear]['fill'] = '#6B5B95' # Royal purple +color_map[nn.Conv2d]['fill'] = 'LightSlateGray' # Light Slate Gray +color_map[nn.ReLU]['fill'] = '#87CEFA' # Light Sky Blue +color_map[nn.MaxPool2d]['fill'] = 'LightSeaGreen' # Light Sea Green +color_map[nn.Flatten]['fill'] = '#98FB98' # Pale Green +color_map[nn.Linear]['fill'] = 'LightSteelBlue' # Light Steel Blue input_shape = (1, 3, 224, 224) visualtorch.layered_view(model, input_shape=input_shape, color_map=color_map ``` -![custom-color](https://github.com/willyfh/visualtorch/assets/5786636/57f28191-d86e-4419-a015-f5fc7fa17b5a) +![custom-color](https://github.com/willyfh/visualtorch/assets/5786636/2e536ffd-8441-4e66-90ff-d152da67363e) ## Contributing diff --git a/visualtorch/graph.py b/visualtorch/graph.py index e0e3d97..8ed844b 100644 --- a/visualtorch/graph.py +++ b/visualtorch/graph.py @@ -121,7 +121,7 @@ def graph_view( current_y = c.y2 + node_spacing - c.fill = color_map.get(type(layer), {}).get("fill", "blue") + c.fill = color_map.get(type(layer), {}).get("fill", "#ADD8E6") c.outline = color_map.get(type(layer), {}).get("outline", "black") layer_nodes.append(c) diff --git a/visualtorch/utils.py b/visualtorch/utils.py index 18f4b0a..01e234b 100644 --- a/visualtorch/utils.py +++ b/visualtorch/utils.py @@ -152,7 +152,7 @@ def __init__(self, colors: list | None = None): self.colors = ( colors if colors is not None - else ["#ffd166", "#ef476f", "#06d6a0", "#118ab2", "#073b4c"] + else ["#FFE4B5", "#ADD8E6", "#98FB98", "#FFA07A", "#D8BFD8"] ) def get_color(self, class_type: type):