Skip to content

Commit

Permalink
add a simple unit test for graph view (#24)
Browse files Browse the repository at this point in the history
  • Loading branch information
willyfh authored Feb 25, 2024
1 parent 0b99f11 commit f81c708
Showing 1 changed file with 32 additions and 0 deletions.
32 changes: 32 additions & 0 deletions tests/test_graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import pytest
import torch.nn as nn

from visualtorch import graph_view


@pytest.fixture
def dense_model():
class SimpleDense(nn.Module):
def __init__(self):
super(SimpleDense, self).__init__()
self.h0 = nn.Linear(4, 8)
self.h1 = nn.Linear(8, 8)
self.h2 = nn.Linear(8, 4)
self.out = nn.Linear(4, 2)

def forward(self, x):
x = self.h0(x)
x = self.h1(x)
x = self.h2(x)
x = self.out(x)
return x

model = SimpleDense()
return model


def test_dense_model_graph_view_runs(dense_model):
try:
_ = graph_view(dense_model, input_shape=(1, 4))
except Exception as e:
pytest.fail(f"graph_view raised an exception with a simple dense model: {e}")

0 comments on commit f81c708

Please sign in to comment.