From c3aa6dd45117bf99e78e05fc564fd1596fa5e45b Mon Sep 17 00:00:00 2001 From: Willy Fitra Hendria Date: Sat, 24 Feb 2024 23:57:33 +0900 Subject: [PATCH 1/2] remove default input shape --- visualtorch/layered.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/visualtorch/layered.py b/visualtorch/layered.py index 455283f..3517508 100644 --- a/visualtorch/layered.py +++ b/visualtorch/layered.py @@ -24,7 +24,7 @@ def layered_view( model, - input_shape=(1, 3, 224, 224), + input_shape, to_file: str | None = None, min_z: int = 10, min_xy: int = 10, From a844d86a56d5821e194e9cff843f553bf82f7760 Mon Sep 17 00:00:00 2001 From: Willy Fitra Hendria Date: Sat, 24 Feb 2024 23:59:33 +0900 Subject: [PATCH 2/2] add input shape to the tests --- tests/test_layered.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_layered.py b/tests/test_layered.py index 31f24b4..84dada0 100644 --- a/tests/test_layered.py +++ b/tests/test_layered.py @@ -55,20 +55,20 @@ def forward(self, x): def test_sequential_model_layered_view_runs(sequential_model): try: - _ = layered_view(sequential_model) + _ = layered_view(sequential_model, input_shape=(1, 3, 224, 224)) except Exception as e: pytest.fail(f"layered_view raised an exception with Sequential model: {e}") def test_module_list_model_layered_view_runs(module_list_model): try: - _ = layered_view(module_list_model) + _ = layered_view(module_list_model, input_shape=(1, 3, 224, 224)) except Exception as e: pytest.fail(f"layered_view raised an exception with ModuleList model: {e}") def test_custom_model_layered_view_runs(custom_model): try: - _ = layered_view(custom_model) + _ = layered_view(custom_model, input_shape=(1, 3, 224, 224)) except Exception as e: pytest.fail(f"layered_view raised an exception with Custom model: {e}")