From f17c86b7f0d2038288d7292cb82dec2433cc91e5 Mon Sep 17 00:00:00 2001 From: Zengyf-CVer <41098760+Zengyf-CVer@users.noreply.github.com> Date: Mon, 22 Nov 2021 03:21:44 +0800 Subject: [PATCH] Save *.npy features on detect.py `--visualize` (#5701) * Add feature map to save npy files Add feature map to save npy files,export npy files with 32 feature maps per layer. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update plots.py * Update plots.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update plots.py Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Glenn Jocher --- tutorial.ipynb | 2 +- utils/plots.py | 7 ++++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/tutorial.ipynb b/tutorial.ipynb index 9440ca8b1788..4ce87c75aa64 100644 --- a/tutorial.ipynb +++ b/tutorial.ipynb @@ -1104,4 +1104,4 @@ "outputs": [] } ] -} \ No newline at end of file +} diff --git a/utils/plots.py b/utils/plots.py index 9919e4d9d88f..69037ee9af70 100644 --- a/utils/plots.py +++ b/utils/plots.py @@ -132,7 +132,7 @@ def feature_visualization(x, module_type, stage, n=32, save_dir=Path('runs/detec if 'Detect' not in module_type: batch, channels, height, width = x.shape # batch, channels, height, width if height > 1 and width > 1: - f = f"stage{stage}_{module_type.split('.')[-1]}_features.png" # filename + f = save_dir / f"stage{stage}_{module_type.split('.')[-1]}_features.png" # filename blocks = torch.chunk(x[0].cpu(), channels, dim=0) # select batch index 0, block by channels n = min(n, channels) # number of plots @@ -143,9 +143,10 @@ def feature_visualization(x, module_type, stage, n=32, save_dir=Path('runs/detec ax[i].imshow(blocks[i].squeeze()) # cmap='gray' ax[i].axis('off') - print(f'Saving {save_dir / f}... ({n}/{channels})') - plt.savefig(save_dir / f, dpi=300, bbox_inches='tight') + print(f'Saving {f}... ({n}/{channels})') + plt.savefig(f, dpi=300, bbox_inches='tight') plt.close() + np.save(str(f.with_suffix('.npy')), x[0].cpu().numpy()) # npy save def hist2d(x, y, n=100):