Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow PyTorch Hub results to display in notebooks #9825

Merged
merged 20 commits into from
Oct 17, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion classify/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def run(
# Dataloader
bs = 1 # batch_size
if webcam:
view_img = check_imshow()
view_img = check_imshow(warn=True)
dataset = LoadStreams(source, img_size=imgsz, transforms=classify_transforms(imgsz[0]), vid_stride=vid_stride)
bs = len(dataset)
elif screenshot:
Expand Down
2 changes: 1 addition & 1 deletion detect.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def run(
# Dataloader
bs = 1 # batch_size
if webcam:
view_img = check_imshow()
view_img = check_imshow(warn=True)
dataset = LoadStreams(source, img_size=imgsz, stride=stride, auto=pt, vid_stride=vid_stride)
bs = len(dataset)
elif screenshot:
Expand Down
13 changes: 9 additions & 4 deletions models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,20 @@
import requests
import torch
import torch.nn as nn
from IPython.display import display
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't ipython be an optional dependency ? IMO, it would make sense to remove it for a minimal setup without "displaying" results.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@vfdev-5 ipython is only used for checking if environment is notebook, clearing notebook cells and showing images in notebook cells. We have in requirements due to the popularity of your Colab notebook.

Does it conflict with anything or is it installing significant other dependencies in your environment?

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the response, @glenn-jocher

Does it conflict with anything or is it installing significant other dependencies in your environment?

well, in terms of space, ipython + deps does not take much space (~1-2MB), I agree. I start to understand better all the context of this work. For demos in colab etc it totally makes sense. In terms of headless server usage (probably less frequent?), it feels weird to install ipython, matplotlib, seaborn (it is just my opinion).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@vfdev-5 got it. Much of this is the difference between training support, visualization support, export support etc. We try to strike the right balance but it's impossible to please everyone.

from PIL import Image
from torch.cuda import amp

from utils import TryExcept
from utils.dataloaders import exif_transpose, letterbox
from utils.general import (LOGGER, ROOT, Profile, check_requirements, check_suffix, check_version, colorstr,
increment_path, make_divisible, non_max_suppression, scale_boxes, xywh2xyxy, xyxy2xywh,
yaml_load)
from utils.general import (LOGGER, ROOT, Profile, check_imshow, check_requirements, check_suffix, check_version,
colorstr, increment_path, make_divisible, non_max_suppression, scale_boxes, xywh2xyxy,
xyxy2xywh, yaml_load)
from utils.plots import Annotator, colors, save_one_box
from utils.torch_utils import copy_attr, smart_inference_mode

CHECK_IMSHOW = check_imshow()


def autopad(k, p=None, d=1): # kernel, padding, dilation
# Pad to 'same' shape outputs
Expand Down Expand Up @@ -756,7 +760,7 @@ def _run(self, pprint=False, show=False, save=False, crop=False, render=False, l

im = Image.fromarray(im.astype(np.uint8)) if isinstance(im, np.ndarray) else im # from np
if show:
im.show(self.files[i]) # show
im.show(self.files[i]) if CHECK_IMSHOW else display(im)
if save:
f = self.files[i]
im.save(save_dir / f) # save
Expand All @@ -772,6 +776,7 @@ def _run(self, pprint=False, show=False, save=False, crop=False, render=False, l
LOGGER.info(f'Saved results to {save_dir}\n')
return crops

@TryExcept('Showing images is not supported in this environment')
def show(self, labels=True):
self._run(show=True, labels=labels) # show results

Expand Down
2 changes: 1 addition & 1 deletion segment/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def run(
# Dataloader
bs = 1 # batch_size
if webcam:
view_img = check_imshow()
view_img = check_imshow(warn=True)
dataset = LoadStreams(source, img_size=imgsz, stride=stride, auto=pt, vid_stride=vid_stride)
bs = len(dataset)
elif screenshot:
Expand Down
2 changes: 1 addition & 1 deletion utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def __enter__(self):

def __exit__(self, exc_type, value, traceback):
if value:
print(emojis(f'{self.msg}{value}'))
print(emojis(f"{self.msg}{': ' if self.msg else ''}{value}"))
return True


Expand Down
2 changes: 1 addition & 1 deletion utils/autoanchor.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def check_anchor_order(m):
m.anchors[:] = m.anchors.flip(0)


@TryExcept(f'{PREFIX}ERROR: ')
@TryExcept(f'{PREFIX}ERROR')
def check_anchors(dataset, model, thr=4.0, imgsz=640):
# Check anchor fit to data, recompute if necessary
m = model.module.model[-1] if hasattr(model, 'module') else model.model[-1] # Detect()
Expand Down
17 changes: 13 additions & 4 deletions utils/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from zipfile import ZipFile

import cv2
import IPython
import numpy as np
import pandas as pd
import pkg_resources as pkg
Expand Down Expand Up @@ -73,6 +74,12 @@ def is_colab():
return 'COLAB_GPU' in os.environ


def is_notebook():
# Is environment a Jupyter notebook? Verified on Colab, Jupyterlab, Kaggle, Paperspace
ipython_type = str(type(IPython.get_ipython()))
return 'colab' in ipython_type or 'zmqshell' in ipython_type


def is_kaggle():
# Is environment a Kaggle Notebook?
return os.environ.get('PWD') == '/kaggle/working' and os.environ.get('KAGGLE_URL_BASE') == 'https://www.kaggle.com'
Expand Down Expand Up @@ -383,18 +390,20 @@ def check_img_size(imgsz, s=32, floor=0):
return new_size


def check_imshow():
def check_imshow(warn=False):
# Check if environment supports image displays
try:
assert not is_docker(), 'cv2.imshow() is disabled in Docker environments'
assert not is_colab(), 'cv2.imshow() is disabled in Google Colab environments'
assert not is_notebook()
assert not is_docker()
assert 'NoneType' not in str(type(IPython.get_ipython())) # SSH terminals, GitHub CI
cv2.imshow('test', np.zeros((1, 1, 3)))
cv2.waitKey(1)
cv2.destroyAllWindows()
cv2.waitKey(1)
return True
except Exception as e:
LOGGER.warning(f'WARNING ⚠️ Environment does not support cv2.imshow() or PIL Image.show() image displays\n{e}')
if warn:
LOGGER.warning(f'WARNING ⚠️ Environment does not support cv2.imshow() or PIL Image.show()\n{e}')
return False


Expand Down
2 changes: 1 addition & 1 deletion utils/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def tp_fp(self):
# fn = self.matrix.sum(0) - tp # false negatives (missed detections)
return tp[:-1], fp[:-1] # remove background class

@TryExcept('WARNING ⚠️ ConfusionMatrix plot failure: ')
@TryExcept('WARNING ⚠️ ConfusionMatrix plot failure')
def plot(self, normalize=True, save_dir='', names=()):
import seaborn as sn

Expand Down