From e3ff7806769444de864060494d1be8e18ce046a1 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Mon, 17 Oct 2022 14:34:33 +0200 Subject: [PATCH] Allow PyTorch Hub results to display in notebooks (#9825) * Allow PyTorch Hub results to display in notebooks * fix CI * fix CI * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix CI * fix CI * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix CI * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix CI * fix CI * fix CI * fix CI * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix CI * fix CI * fix CI Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- classify/predict.py | 2 +- detect.py | 2 +- models/common.py | 13 +++++++++---- segment/predict.py | 2 +- utils/__init__.py | 2 +- utils/autoanchor.py | 2 +- utils/general.py | 17 +++++++++++++---- utils/metrics.py | 2 +- 8 files changed, 28 insertions(+), 14 deletions(-) diff --git a/classify/predict.py b/classify/predict.py index 9114aab1d703..9373649bf27d 100644 --- a/classify/predict.py +++ b/classify/predict.py @@ -91,7 +91,7 @@ def run( # Dataloader bs = 1 # batch_size if webcam: - view_img = check_imshow() + view_img = check_imshow(warn=True) dataset = LoadStreams(source, img_size=imgsz, transforms=classify_transforms(imgsz[0]), vid_stride=vid_stride) bs = len(dataset) elif screenshot: diff --git a/detect.py b/detect.py index 8f48d8d28000..98af7235ea69 100644 --- a/detect.py +++ b/detect.py @@ -99,7 +99,7 @@ def run( # Dataloader bs = 1 # batch_size if webcam: - view_img = check_imshow() + view_img = check_imshow(warn=True) dataset = LoadStreams(source, img_size=imgsz, stride=stride, auto=pt, vid_stride=vid_stride) bs = len(dataset) elif screenshot: diff --git a/models/common.py b/models/common.py index d889d0292c61..e6da429de3e5 100644 --- a/models/common.py +++ b/models/common.py @@ -18,16 +18,20 @@ import requests import torch import torch.nn as nn +from IPython.display import display from PIL import Image from torch.cuda import amp +from utils import TryExcept from utils.dataloaders import exif_transpose, letterbox -from utils.general import (LOGGER, ROOT, Profile, check_requirements, check_suffix, check_version, colorstr, - increment_path, make_divisible, non_max_suppression, scale_boxes, xywh2xyxy, xyxy2xywh, - yaml_load) +from utils.general import (LOGGER, ROOT, Profile, check_imshow, check_requirements, check_suffix, check_version, + colorstr, increment_path, make_divisible, non_max_suppression, scale_boxes, xywh2xyxy, + xyxy2xywh, yaml_load) from utils.plots import Annotator, colors, save_one_box from utils.torch_utils import copy_attr, smart_inference_mode +CHECK_IMSHOW = check_imshow() + def autopad(k, p=None, d=1): # kernel, padding, dilation # Pad to 'same' shape outputs @@ -756,7 +760,7 @@ def _run(self, pprint=False, show=False, save=False, crop=False, render=False, l im = Image.fromarray(im.astype(np.uint8)) if isinstance(im, np.ndarray) else im # from np if show: - im.show(self.files[i]) # show + im.show(self.files[i]) if CHECK_IMSHOW else display(im) if save: f = self.files[i] im.save(save_dir / f) # save @@ -772,6 +776,7 @@ def _run(self, pprint=False, show=False, save=False, crop=False, render=False, l LOGGER.info(f'Saved results to {save_dir}\n') return crops + @TryExcept('Showing images is not supported in this environment') def show(self, labels=True): self._run(show=True, labels=labels) # show results diff --git a/segment/predict.py b/segment/predict.py index 94117cd78633..44d6d3904c19 100644 --- a/segment/predict.py +++ b/segment/predict.py @@ -102,7 +102,7 @@ def run( # Dataloader bs = 1 # batch_size if webcam: - view_img = check_imshow() + view_img = check_imshow(warn=True) dataset = LoadStreams(source, img_size=imgsz, stride=stride, auto=pt, vid_stride=vid_stride) bs = len(dataset) elif screenshot: diff --git a/utils/__init__.py b/utils/__init__.py index 8403a6149827..0afe6f475625 100644 --- a/utils/__init__.py +++ b/utils/__init__.py @@ -23,7 +23,7 @@ def __enter__(self): def __exit__(self, exc_type, value, traceback): if value: - print(emojis(f'{self.msg}{value}')) + print(emojis(f"{self.msg}{': ' if self.msg else ''}{value}")) return True diff --git a/utils/autoanchor.py b/utils/autoanchor.py index 7e7e9985d68a..cfc4c276e3aa 100644 --- a/utils/autoanchor.py +++ b/utils/autoanchor.py @@ -26,7 +26,7 @@ def check_anchor_order(m): m.anchors[:] = m.anchors.flip(0) -@TryExcept(f'{PREFIX}ERROR: ') +@TryExcept(f'{PREFIX}ERROR') def check_anchors(dataset, model, thr=4.0, imgsz=640): # Check anchor fit to data, recompute if necessary m = model.module.model[-1] if hasattr(model, 'module') else model.model[-1] # Detect() diff --git a/utils/general.py b/utils/general.py index d9d54d9e4f71..76bc0b1d7a79 100644 --- a/utils/general.py +++ b/utils/general.py @@ -27,6 +27,7 @@ from zipfile import ZipFile import cv2 +import IPython import numpy as np import pandas as pd import pkg_resources as pkg @@ -73,6 +74,12 @@ def is_colab(): return 'COLAB_GPU' in os.environ +def is_notebook(): + # Is environment a Jupyter notebook? Verified on Colab, Jupyterlab, Kaggle, Paperspace + ipython_type = str(type(IPython.get_ipython())) + return 'colab' in ipython_type or 'zmqshell' in ipython_type + + def is_kaggle(): # Is environment a Kaggle Notebook? return os.environ.get('PWD') == '/kaggle/working' and os.environ.get('KAGGLE_URL_BASE') == 'https://www.kaggle.com' @@ -383,18 +390,20 @@ def check_img_size(imgsz, s=32, floor=0): return new_size -def check_imshow(): +def check_imshow(warn=False): # Check if environment supports image displays try: - assert not is_docker(), 'cv2.imshow() is disabled in Docker environments' - assert not is_colab(), 'cv2.imshow() is disabled in Google Colab environments' + assert not is_notebook() + assert not is_docker() + assert 'NoneType' not in str(type(IPython.get_ipython())) # SSH terminals, GitHub CI cv2.imshow('test', np.zeros((1, 1, 3))) cv2.waitKey(1) cv2.destroyAllWindows() cv2.waitKey(1) return True except Exception as e: - LOGGER.warning(f'WARNING ⚠️ Environment does not support cv2.imshow() or PIL Image.show() image displays\n{e}') + if warn: + LOGGER.warning(f'WARNING ⚠️ Environment does not support cv2.imshow() or PIL Image.show()\n{e}') return False diff --git a/utils/metrics.py b/utils/metrics.py index ed611d7d38fa..f0bc787e1518 100644 --- a/utils/metrics.py +++ b/utils/metrics.py @@ -186,7 +186,7 @@ def tp_fp(self): # fn = self.matrix.sum(0) - tp # false negatives (missed detections) return tp[:-1], fp[:-1] # remove background class - @TryExcept('WARNING ⚠️ ConfusionMatrix plot failure: ') + @TryExcept('WARNING ⚠️ ConfusionMatrix plot failure') def plot(self, normalize=True, save_dir='', names=()): import seaborn as sn