Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New TryExcept decorator #9154

Merged
merged 2 commits into from
Aug 25, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,33 @@
utils/initialization
"""

import contextlib
import threading


class TryExcept(contextlib.ContextDecorator):
# YOLOv5 TryExcept class. Usage: @TryExcept() decorator or 'with TryExcept():' context manager
def __init__(self, msg='default message here'):
self.msg = msg

def __enter__(self):
pass

def __exit__(self, exc_type, value, traceback):
if value:
print(f'{self.msg}: {value}')
return True


def threaded(func):
# Multi-threads a target function and returns thread. Usage: @threaded decorator
def wrapper(*args, **kwargs):
thread = threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True)
thread.start()
return thread

return wrapper


def notebook_init(verbose=True):
# Check system software and hardware
Expand Down
27 changes: 3 additions & 24 deletions utils/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import shutil
import signal
import sys
import threading
import time
import urllib
from datetime import datetime
Expand All @@ -34,6 +33,7 @@
import torchvision
import yaml

from utils import TryExcept
from utils.downloads import gsutil_getsize
from utils.metrics import box_iou, fitness

Expand Down Expand Up @@ -195,27 +195,6 @@ def __exit__(self, exc_type, exc_val, exc_tb):
os.chdir(self.cwd)


def try_except(func):
# try-except function. Usage: @try_except decorator
def handler(*args, **kwargs):
try:
func(*args, **kwargs)
except Exception as e:
print(e)

return handler


def threaded(func):
# Multi-threads a target function and returns thread. Usage: @threaded decorator
def wrapper(*args, **kwargs):
thread = threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True)
thread.start()
return thread

return wrapper


def methods(instance):
# Get class/instance methods
return [f for f in dir(instance) if callable(getattr(instance, f)) and not f.startswith("__")]
Expand Down Expand Up @@ -319,7 +298,7 @@ def git_describe(path=ROOT): # path must be a directory
return ''


@try_except
@TryExcept()
@WorkingDirectory(ROOT)
def check_git_status(repo='ultralytics/yolov5'):
# YOLOv5 status check, recommend 'git pull' if code is out of date
Expand Down Expand Up @@ -364,7 +343,7 @@ def check_version(current='0.0.0', minimum='0.0.0', name='version ', pinned=Fals
return result


@try_except
@TryExcept()
def check_requirements(requirements=ROOT / 'requirements.txt', exclude=(), install=True, cmds=()):
# Check installed dependencies meet YOLOv5 requirements (pass *.txt file or list of packages)
prefix = colorstr('red', 'bold', 'requirements:')
Expand Down
73 changes: 38 additions & 35 deletions utils/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
import numpy as np
import torch

from utils import TryExcept, threaded


def fitness(x):
# Model fitness as a weighted combination of metrics
Expand Down Expand Up @@ -184,36 +186,35 @@ def tp_fp(self):
# fn = self.matrix.sum(0) - tp # false negatives (missed detections)
return tp[:-1], fp[:-1] # remove background class

@TryExcept('WARNING: ConfusionMatrix plot failure')
def plot(self, normalize=True, save_dir='', names=()):
try:
import seaborn as sn

array = self.matrix / ((self.matrix.sum(0).reshape(1, -1) + 1E-9) if normalize else 1) # normalize columns
array[array < 0.005] = np.nan # don't annotate (would appear as 0.00)

fig = plt.figure(figsize=(12, 9), tight_layout=True)
nc, nn = self.nc, len(names) # number of classes, names
sn.set(font_scale=1.0 if nc < 50 else 0.8) # for label size
labels = (0 < nn < 99) and (nn == nc) # apply names to ticklabels
with warnings.catch_warnings():
warnings.simplefilter('ignore') # suppress empty matrix RuntimeWarning: All-NaN slice encountered
sn.heatmap(array,
annot=nc < 30,
annot_kws={
"size": 8},
cmap='Blues',
fmt='.2f',
square=True,
vmin=0.0,
xticklabels=names + ['background FP'] if labels else "auto",
yticklabels=names + ['background FN'] if labels else "auto").set_facecolor((1, 1, 1))
fig.axes[0].set_xlabel('True')
fig.axes[0].set_ylabel('Predicted')
plt.title('Confusion Matrix')
fig.savefig(Path(save_dir) / 'confusion_matrix.png', dpi=250)
plt.close()
except Exception as e:
print(f'WARNING: ConfusionMatrix plot failure: {e}')
import seaborn as sn

array = self.matrix / ((self.matrix.sum(0).reshape(1, -1) + 1E-9) if normalize else 1) # normalize columns
array[array < 0.005] = np.nan # don't annotate (would appear as 0.00)

fig, ax = plt.subplots(1, 1, figsize=(12, 9), tight_layout=True)
nc, nn = self.nc, len(names) # number of classes, names
sn.set(font_scale=1.0 if nc < 50 else 0.8) # for label size
labels = (0 < nn < 99) and (nn == nc) # apply names to ticklabels
with warnings.catch_warnings():
warnings.simplefilter('ignore') # suppress empty matrix RuntimeWarning: All-NaN slice encountered
sn.heatmap(array,
ax=ax,
annot=nc < 30,
annot_kws={
"size": 8},
cmap='Blues',
fmt='.2f',
square=True,
vmin=0.0,
xticklabels=names + ['background FP'] if labels else "auto",
yticklabels=names + ['background FN'] if labels else "auto").set_facecolor((1, 1, 1))
ax.set_ylabel('True')
ax.set_ylabel('Predicted')
ax.set_title('Confusion Matrix')
fig.savefig(Path(save_dir) / 'confusion_matrix.png', dpi=250)
plt.close(fig)

def print(self):
for i in range(self.nc + 1):
Expand Down Expand Up @@ -320,6 +321,7 @@ def wh_iou(wh1, wh2, eps=1e-7):
# Plots ----------------------------------------------------------------------------------------------------------------


@threaded
def plot_pr_curve(px, py, ap, save_dir=Path('pr_curve.png'), names=()):
# Precision-recall curve
fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
Expand All @@ -336,12 +338,13 @@ def plot_pr_curve(px, py, ap, save_dir=Path('pr_curve.png'), names=()):
ax.set_ylabel('Precision')
ax.set_xlim(0, 1)
ax.set_ylim(0, 1)
plt.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
plt.title('Precision-Recall Curve')
ax.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
ax.set_title('Precision-Recall Curve')
fig.savefig(save_dir, dpi=250)
plt.close()
plt.close(fig)


@threaded
def plot_mc_curve(px, py, save_dir=Path('mc_curve.png'), names=(), xlabel='Confidence', ylabel='Metric'):
# Metric-confidence curve
fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
Expand All @@ -358,7 +361,7 @@ def plot_mc_curve(px, py, save_dir=Path('mc_curve.png'), names=(), xlabel='Confi
ax.set_ylabel(ylabel)
ax.set_xlim(0, 1)
ax.set_ylim(0, 1)
plt.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
plt.title(f'{ylabel}-Confidence Curve')
ax.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
ax.set_title(f'{ylabel}-Confidence Curve')
fig.savefig(save_dir, dpi=250)
plt.close()
plt.close(fig)
5 changes: 3 additions & 2 deletions utils/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,9 @@
import torch
from PIL import Image, ImageDraw, ImageFont

from utils import TryExcept, threaded
from utils.general import (CONFIG_DIR, FONT, LOGGER, check_font, check_requirements, clip_coords, increment_path,
is_ascii, threaded, try_except, xywh2xyxy, xyxy2xywh)
is_ascii, xywh2xyxy, xyxy2xywh)
from utils.metrics import fitness

# Settings
Expand Down Expand Up @@ -339,7 +340,7 @@ def plot_val_study(file='', dir='', x=None): # from utils.plots import *; plot_
plt.savefig(f, dpi=300)


@try_except # known issue https://github.com/ultralytics/yolov5/issues/5395
@TryExcept() # known issue https://github.com/ultralytics/yolov5/issues/5395
def plot_labels(labels, names=(), save_dir=Path('')):
# plot dataset labels
LOGGER.info(f"Plotting labels to {save_dir / 'labels.jpg'}... ")
Expand Down