Skip to content

Commit

Permalink
if plots
Browse files Browse the repository at this point in the history
  • Loading branch information
glenn-jocher committed Nov 22, 2020
1 parent 36c6cdb commit 4faf5dc
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 4 deletions.
2 changes: 1 addition & 1 deletion test.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ def test(data,
parser.add_argument('--weights', nargs='+', type=str, default='yolov5s.pt', help='model.pt path(s)')
parser.add_argument('--data', type=str, default='data/coco128.yaml', help='*.data path')
parser.add_argument('--batch-size', type=int, default=32, help='size of each image batch')
parser.add_argument('--img-size', type=int, default=640, help='inference size (pixels)')
parser.add_argument('--img-size', type=int, default=256, help='inference size (pixels)')
parser.add_argument('--conf-thres', type=float, default=0.001, help='object confidence threshold')
parser.add_argument('--iou-thres', type=float, default=0.6, help='IOU threshold for NMS')
parser.add_argument('--task', default='val', help="'val', 'test', 'study'")
Expand Down
3 changes: 2 additions & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,7 +397,8 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
plot_results(save_dir=save_dir) # save as results.png
if wandb:
files = ['results.png', 'precision_recall_curve.png', 'confusion_matrix.png']
wandb.log({"Results": [wandb.Image(str(save_dir / f), caption=f) for f in files]})
wandb.log({"Results": [wandb.Image(str(save_dir / f), caption=f) for f in files
if (save_dir / f).exists()]})
logger.info('%g epochs completed in %.3f hours.\n' % (epoch - start_epoch + 1, (time.time() - t0) / 3600))
else:
dist.destroy_process_group()
Expand Down
4 changes: 2 additions & 2 deletions utils/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,9 +163,9 @@ def plot(self, save_dir='', names=()):
array = self.matrix / (self.matrix.sum(0).reshape(1, self.nc + 1) + 1E-6) # normalize
array[array < 0.005] = np.nan # don't annotate (would appear as 0.00)

fig, ax = plt.figure(figsize=(12, 9))
fig = plt.figure(figsize=(12, 9))
# sn.set(font_scale=1.0) # for label size
sn.heatmap(array, annot=self.nc < 30, annot_kws={"size": 8}, cmap='Blues', fmt='.2f', square=True, ax=ax,
sn.heatmap(array, annot=self.nc < 30, annot_kws={"size": 8}, cmap='Blues', fmt='.2f', square=True,
xticklabels=names + ['background FN'] if names else "auto",
yticklabels=names + ['background FP'] if names else "auto").set_facecolor((1, 1, 1))
fig.tight_layout()
Expand Down

0 comments on commit 4faf5dc

Please sign in to comment.