From 3a708d258172885d865c8a6f0f2706929943afda Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Thu, 25 Aug 2022 14:34:26 +0200 Subject: [PATCH] New TryExcept decorator (#9154) * New TryExcept decorator * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- utils/__init__.py | 27 ++++++++++++++++++ utils/general.py | 27 ++---------------- utils/metrics.py | 73 ++++++++++++++++++++++++----------------------- utils/plots.py | 5 ++-- 4 files changed, 71 insertions(+), 61 deletions(-) diff --git a/utils/__init__.py b/utils/__init__.py index a63c473a4340..7466a486caf4 100644 --- a/utils/__init__.py +++ b/utils/__init__.py @@ -3,6 +3,33 @@ utils/initialization """ +import contextlib +import threading + + +class TryExcept(contextlib.ContextDecorator): + # YOLOv5 TryExcept class. Usage: @TryExcept() decorator or 'with TryExcept():' context manager + def __init__(self, msg='default message here'): + self.msg = msg + + def __enter__(self): + pass + + def __exit__(self, exc_type, value, traceback): + if value: + print(f'{self.msg}: {value}') + return True + + +def threaded(func): + # Multi-threads a target function and returns thread. Usage: @threaded decorator + def wrapper(*args, **kwargs): + thread = threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True) + thread.start() + return thread + + return wrapper + def notebook_init(verbose=True): # Check system software and hardware diff --git a/utils/general.py b/utils/general.py index d8c90f10ac8f..91b13f84a6c4 100755 --- a/utils/general.py +++ b/utils/general.py @@ -15,7 +15,6 @@ import shutil import signal import sys -import threading import time import urllib from datetime import datetime @@ -34,6 +33,7 @@ import torchvision import yaml +from utils import TryExcept from utils.downloads import gsutil_getsize from utils.metrics import box_iou, fitness @@ -195,27 +195,6 @@ def __exit__(self, exc_type, exc_val, exc_tb): os.chdir(self.cwd) -def try_except(func): - # try-except function. Usage: @try_except decorator - def handler(*args, **kwargs): - try: - func(*args, **kwargs) - except Exception as e: - print(e) - - return handler - - -def threaded(func): - # Multi-threads a target function and returns thread. Usage: @threaded decorator - def wrapper(*args, **kwargs): - thread = threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True) - thread.start() - return thread - - return wrapper - - def methods(instance): # Get class/instance methods return [f for f in dir(instance) if callable(getattr(instance, f)) and not f.startswith("__")] @@ -319,7 +298,7 @@ def git_describe(path=ROOT): # path must be a directory return '' -@try_except +@TryExcept() @WorkingDirectory(ROOT) def check_git_status(repo='ultralytics/yolov5'): # YOLOv5 status check, recommend 'git pull' if code is out of date @@ -364,7 +343,7 @@ def check_version(current='0.0.0', minimum='0.0.0', name='version ', pinned=Fals return result -@try_except +@TryExcept() def check_requirements(requirements=ROOT / 'requirements.txt', exclude=(), install=True, cmds=()): # Check installed dependencies meet YOLOv5 requirements (pass *.txt file or list of packages) prefix = colorstr('red', 'bold', 'requirements:') diff --git a/utils/metrics.py b/utils/metrics.py index 8fa3c7e217c7..de1bf05b326b 100644 --- a/utils/metrics.py +++ b/utils/metrics.py @@ -11,6 +11,8 @@ import numpy as np import torch +from utils import TryExcept, threaded + def fitness(x): # Model fitness as a weighted combination of metrics @@ -184,36 +186,35 @@ def tp_fp(self): # fn = self.matrix.sum(0) - tp # false negatives (missed detections) return tp[:-1], fp[:-1] # remove background class + @TryExcept('WARNING: ConfusionMatrix plot failure') def plot(self, normalize=True, save_dir='', names=()): - try: - import seaborn as sn - - array = self.matrix / ((self.matrix.sum(0).reshape(1, -1) + 1E-9) if normalize else 1) # normalize columns - array[array < 0.005] = np.nan # don't annotate (would appear as 0.00) - - fig = plt.figure(figsize=(12, 9), tight_layout=True) - nc, nn = self.nc, len(names) # number of classes, names - sn.set(font_scale=1.0 if nc < 50 else 0.8) # for label size - labels = (0 < nn < 99) and (nn == nc) # apply names to ticklabels - with warnings.catch_warnings(): - warnings.simplefilter('ignore') # suppress empty matrix RuntimeWarning: All-NaN slice encountered - sn.heatmap(array, - annot=nc < 30, - annot_kws={ - "size": 8}, - cmap='Blues', - fmt='.2f', - square=True, - vmin=0.0, - xticklabels=names + ['background FP'] if labels else "auto", - yticklabels=names + ['background FN'] if labels else "auto").set_facecolor((1, 1, 1)) - fig.axes[0].set_xlabel('True') - fig.axes[0].set_ylabel('Predicted') - plt.title('Confusion Matrix') - fig.savefig(Path(save_dir) / 'confusion_matrix.png', dpi=250) - plt.close() - except Exception as e: - print(f'WARNING: ConfusionMatrix plot failure: {e}') + import seaborn as sn + + array = self.matrix / ((self.matrix.sum(0).reshape(1, -1) + 1E-9) if normalize else 1) # normalize columns + array[array < 0.005] = np.nan # don't annotate (would appear as 0.00) + + fig, ax = plt.subplots(1, 1, figsize=(12, 9), tight_layout=True) + nc, nn = self.nc, len(names) # number of classes, names + sn.set(font_scale=1.0 if nc < 50 else 0.8) # for label size + labels = (0 < nn < 99) and (nn == nc) # apply names to ticklabels + with warnings.catch_warnings(): + warnings.simplefilter('ignore') # suppress empty matrix RuntimeWarning: All-NaN slice encountered + sn.heatmap(array, + ax=ax, + annot=nc < 30, + annot_kws={ + "size": 8}, + cmap='Blues', + fmt='.2f', + square=True, + vmin=0.0, + xticklabels=names + ['background FP'] if labels else "auto", + yticklabels=names + ['background FN'] if labels else "auto").set_facecolor((1, 1, 1)) + ax.set_ylabel('True') + ax.set_ylabel('Predicted') + ax.set_title('Confusion Matrix') + fig.savefig(Path(save_dir) / 'confusion_matrix.png', dpi=250) + plt.close(fig) def print(self): for i in range(self.nc + 1): @@ -320,6 +321,7 @@ def wh_iou(wh1, wh2, eps=1e-7): # Plots ---------------------------------------------------------------------------------------------------------------- +@threaded def plot_pr_curve(px, py, ap, save_dir=Path('pr_curve.png'), names=()): # Precision-recall curve fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True) @@ -336,12 +338,13 @@ def plot_pr_curve(px, py, ap, save_dir=Path('pr_curve.png'), names=()): ax.set_ylabel('Precision') ax.set_xlim(0, 1) ax.set_ylim(0, 1) - plt.legend(bbox_to_anchor=(1.04, 1), loc="upper left") - plt.title('Precision-Recall Curve') + ax.legend(bbox_to_anchor=(1.04, 1), loc="upper left") + ax.set_title('Precision-Recall Curve') fig.savefig(save_dir, dpi=250) - plt.close() + plt.close(fig) +@threaded def plot_mc_curve(px, py, save_dir=Path('mc_curve.png'), names=(), xlabel='Confidence', ylabel='Metric'): # Metric-confidence curve fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True) @@ -358,7 +361,7 @@ def plot_mc_curve(px, py, save_dir=Path('mc_curve.png'), names=(), xlabel='Confi ax.set_ylabel(ylabel) ax.set_xlim(0, 1) ax.set_ylim(0, 1) - plt.legend(bbox_to_anchor=(1.04, 1), loc="upper left") - plt.title(f'{ylabel}-Confidence Curve') + ax.legend(bbox_to_anchor=(1.04, 1), loc="upper left") + ax.set_title(f'{ylabel}-Confidence Curve') fig.savefig(save_dir, dpi=250) - plt.close() + plt.close(fig) diff --git a/utils/plots.py b/utils/plots.py index d35e2bdd168a..2aa163268336 100644 --- a/utils/plots.py +++ b/utils/plots.py @@ -19,8 +19,9 @@ import torch from PIL import Image, ImageDraw, ImageFont +from utils import TryExcept, threaded from utils.general import (CONFIG_DIR, FONT, LOGGER, check_font, check_requirements, clip_coords, increment_path, - is_ascii, threaded, try_except, xywh2xyxy, xyxy2xywh) + is_ascii, xywh2xyxy, xyxy2xywh) from utils.metrics import fitness # Settings @@ -339,7 +340,7 @@ def plot_val_study(file='', dir='', x=None): # from utils.plots import *; plot_ plt.savefig(f, dpi=300) -@try_except # known issue https://github.com/ultralytics/yolov5/issues/5395 +@TryExcept() # known issue https://github.com/ultralytics/yolov5/issues/5395 def plot_labels(labels, names=(), save_dir=Path('')): # plot dataset labels LOGGER.info(f"Plotting labels to {save_dir / 'labels.jpg'}... ")