diff --git a/tutorial.ipynb b/tutorial.ipynb index 9440ca8b1788..4ce87c75aa64 100644 --- a/tutorial.ipynb +++ b/tutorial.ipynb @@ -1104,4 +1104,4 @@ "outputs": [] } ] -} \ No newline at end of file +} diff --git a/utils/plots.py b/utils/plots.py index 9919e4d9d88f..69037ee9af70 100644 --- a/utils/plots.py +++ b/utils/plots.py @@ -132,7 +132,7 @@ def feature_visualization(x, module_type, stage, n=32, save_dir=Path('runs/detec if 'Detect' not in module_type: batch, channels, height, width = x.shape # batch, channels, height, width if height > 1 and width > 1: - f = f"stage{stage}_{module_type.split('.')[-1]}_features.png" # filename + f = save_dir / f"stage{stage}_{module_type.split('.')[-1]}_features.png" # filename blocks = torch.chunk(x[0].cpu(), channels, dim=0) # select batch index 0, block by channels n = min(n, channels) # number of plots @@ -143,9 +143,10 @@ def feature_visualization(x, module_type, stage, n=32, save_dir=Path('runs/detec ax[i].imshow(blocks[i].squeeze()) # cmap='gray' ax[i].axis('off') - print(f'Saving {save_dir / f}... ({n}/{channels})') - plt.savefig(save_dir / f, dpi=300, bbox_inches='tight') + print(f'Saving {f}... ({n}/{channels})') + plt.savefig(f, dpi=300, bbox_inches='tight') plt.close() + np.save(str(f.with_suffix('.npy')), x[0].cpu().numpy()) # npy save def hist2d(x, y, n=100):