Skip to content

Commit

Permalink
New TryExcept decorator (#9154)
Browse files Browse the repository at this point in the history
* New TryExcept decorator

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
glenn-jocher and pre-commit-ci[bot] committed Aug 25, 2022
1 parent f0e5a60 commit d07ddc6
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 61 deletions.
27 changes: 27 additions & 0 deletions utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,33 @@
utils/initialization
"""

import contextlib
import threading


class TryExcept(contextlib.ContextDecorator):
# YOLOv5 TryExcept class. Usage: @TryExcept() decorator or 'with TryExcept():' context manager
def __init__(self, msg='default message here'):
self.msg = msg

def __enter__(self):
pass

def __exit__(self, exc_type, value, traceback):
if value:
print(f'{self.msg}: {value}')
return True


def threaded(func):
# Multi-threads a target function and returns thread. Usage: @threaded decorator
def wrapper(*args, **kwargs):
thread = threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True)
thread.start()
return thread

return wrapper


def notebook_init(verbose=True):
# Check system software and hardware
Expand Down
27 changes: 3 additions & 24 deletions utils/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import shutil
import signal
import sys
import threading
import time
import urllib
from datetime import datetime
Expand All @@ -34,6 +33,7 @@
import torchvision
import yaml

from utils import TryExcept
from utils.downloads import gsutil_getsize
from utils.metrics import box_iou, fitness

Expand Down Expand Up @@ -195,27 +195,6 @@ def __exit__(self, exc_type, exc_val, exc_tb):
os.chdir(self.cwd)


def try_except(func):
# try-except function. Usage: @try_except decorator
def handler(*args, **kwargs):
try:
func(*args, **kwargs)
except Exception as e:
print(e)

return handler


def threaded(func):
# Multi-threads a target function and returns thread. Usage: @threaded decorator
def wrapper(*args, **kwargs):
thread = threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True)
thread.start()
return thread

return wrapper


def methods(instance):
# Get class/instance methods
return [f for f in dir(instance) if callable(getattr(instance, f)) and not f.startswith("__")]
Expand Down Expand Up @@ -319,7 +298,7 @@ def git_describe(path=ROOT): # path must be a directory
return ''


@try_except
@TryExcept()
@WorkingDirectory(ROOT)
def check_git_status(repo='ultralytics/yolov5'):
# YOLOv5 status check, recommend 'git pull' if code is out of date
Expand Down Expand Up @@ -364,7 +343,7 @@ def check_version(current='0.0.0', minimum='0.0.0', name='version ', pinned=Fals
return result


@try_except
@TryExcept()
def check_requirements(requirements=ROOT / 'requirements.txt', exclude=(), install=True, cmds=()):
# Check installed dependencies meet YOLOv5 requirements (pass *.txt file or list of packages)
prefix = colorstr('red', 'bold', 'requirements:')
Expand Down
73 changes: 38 additions & 35 deletions utils/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
import numpy as np
import torch

from utils import TryExcept, threaded


def fitness(x):
# Model fitness as a weighted combination of metrics
Expand Down Expand Up @@ -184,36 +186,35 @@ def tp_fp(self):
# fn = self.matrix.sum(0) - tp # false negatives (missed detections)
return tp[:-1], fp[:-1] # remove background class

@TryExcept('WARNING: ConfusionMatrix plot failure')
def plot(self, normalize=True, save_dir='', names=()):
try:
import seaborn as sn

array = self.matrix / ((self.matrix.sum(0).reshape(1, -1) + 1E-9) if normalize else 1) # normalize columns
array[array < 0.005] = np.nan # don't annotate (would appear as 0.00)

fig = plt.figure(figsize=(12, 9), tight_layout=True)
nc, nn = self.nc, len(names) # number of classes, names
sn.set(font_scale=1.0 if nc < 50 else 0.8) # for label size
labels = (0 < nn < 99) and (nn == nc) # apply names to ticklabels
with warnings.catch_warnings():
warnings.simplefilter('ignore') # suppress empty matrix RuntimeWarning: All-NaN slice encountered
sn.heatmap(array,
annot=nc < 30,
annot_kws={
"size": 8},
cmap='Blues',
fmt='.2f',
square=True,
vmin=0.0,
xticklabels=names + ['background FP'] if labels else "auto",
yticklabels=names + ['background FN'] if labels else "auto").set_facecolor((1, 1, 1))
fig.axes[0].set_xlabel('True')
fig.axes[0].set_ylabel('Predicted')
plt.title('Confusion Matrix')
fig.savefig(Path(save_dir) / 'confusion_matrix.png', dpi=250)
plt.close()
except Exception as e:
print(f'WARNING: ConfusionMatrix plot failure: {e}')
import seaborn as sn

array = self.matrix / ((self.matrix.sum(0).reshape(1, -1) + 1E-9) if normalize else 1) # normalize columns
array[array < 0.005] = np.nan # don't annotate (would appear as 0.00)

fig, ax = plt.subplots(1, 1, figsize=(12, 9), tight_layout=True)
nc, nn = self.nc, len(names) # number of classes, names
sn.set(font_scale=1.0 if nc < 50 else 0.8) # for label size
labels = (0 < nn < 99) and (nn == nc) # apply names to ticklabels
with warnings.catch_warnings():
warnings.simplefilter('ignore') # suppress empty matrix RuntimeWarning: All-NaN slice encountered
sn.heatmap(array,
ax=ax,
annot=nc < 30,
annot_kws={
"size": 8},
cmap='Blues',
fmt='.2f',
square=True,
vmin=0.0,
xticklabels=names + ['background FP'] if labels else "auto",
yticklabels=names + ['background FN'] if labels else "auto").set_facecolor((1, 1, 1))
ax.set_ylabel('True')
ax.set_ylabel('Predicted')
ax.set_title('Confusion Matrix')
fig.savefig(Path(save_dir) / 'confusion_matrix.png', dpi=250)
plt.close(fig)

def print(self):
for i in range(self.nc + 1):
Expand Down Expand Up @@ -320,6 +321,7 @@ def wh_iou(wh1, wh2, eps=1e-7):
# Plots ----------------------------------------------------------------------------------------------------------------


@threaded
def plot_pr_curve(px, py, ap, save_dir=Path('pr_curve.png'), names=()):
# Precision-recall curve
fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
Expand All @@ -336,12 +338,13 @@ def plot_pr_curve(px, py, ap, save_dir=Path('pr_curve.png'), names=()):
ax.set_ylabel('Precision')
ax.set_xlim(0, 1)
ax.set_ylim(0, 1)
plt.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
plt.title('Precision-Recall Curve')
ax.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
ax.set_title('Precision-Recall Curve')
fig.savefig(save_dir, dpi=250)
plt.close()
plt.close(fig)


@threaded
def plot_mc_curve(px, py, save_dir=Path('mc_curve.png'), names=(), xlabel='Confidence', ylabel='Metric'):
# Metric-confidence curve
fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
Expand All @@ -358,7 +361,7 @@ def plot_mc_curve(px, py, save_dir=Path('mc_curve.png'), names=(), xlabel='Confi
ax.set_ylabel(ylabel)
ax.set_xlim(0, 1)
ax.set_ylim(0, 1)
plt.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
plt.title(f'{ylabel}-Confidence Curve')
ax.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
ax.set_title(f'{ylabel}-Confidence Curve')
fig.savefig(save_dir, dpi=250)
plt.close()
plt.close(fig)
5 changes: 3 additions & 2 deletions utils/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,9 @@
import torch
from PIL import Image, ImageDraw, ImageFont

from utils import TryExcept, threaded
from utils.general import (CONFIG_DIR, FONT, LOGGER, check_font, check_requirements, clip_coords, increment_path,
is_ascii, threaded, try_except, xywh2xyxy, xyxy2xywh)
is_ascii, xywh2xyxy, xyxy2xywh)
from utils.metrics import fitness

# Settings
Expand Down Expand Up @@ -339,7 +340,7 @@ def plot_val_study(file='', dir='', x=None): # from utils.plots import *; plot_
plt.savefig(f, dpi=300)


@try_except # known issue https://github.com/ultralytics/yolov5/issues/5395
@TryExcept() # known issue https://github.com/ultralytics/yolov5/issues/5395
def plot_labels(labels, names=(), save_dir=Path('')):
# plot dataset labels
LOGGER.info(f"Plotting labels to {save_dir / 'labels.jpg'}... ")
Expand Down

0 comments on commit d07ddc6

Please sign in to comment.