Skip to content

Commit

Permalink
Add EXIF rotation to YOLOv5 Hub inference (#3852)
Browse files Browse the repository at this point in the history
* rotating an image according to its exif tag

* Update common.py

* Update datasets.py

* Update datasets.py

faster

* delete extraneous gpg file

* Update common.py

Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
  • Loading branch information
vaaliferov and glenn-jocher committed Jul 2, 2021
1 parent 4717a3b commit 831773f
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 4 deletions.
9 changes: 5 additions & 4 deletions models/common.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
# YOLOv5 common modules

import math
from copy import copy
from pathlib import Path

import math
import numpy as np
import pandas as pd
import requests
Expand All @@ -12,7 +12,7 @@
from PIL import Image
from torch.cuda import amp

from utils.datasets import letterbox
from utils.datasets import exif_transpose, letterbox
from utils.general import non_max_suppression, make_divisible, scale_coords, increment_path, xyxy2xywh, save_one_box
from utils.plots import colors, plot_one_box
from utils.torch_utils import time_synchronized
Expand Down Expand Up @@ -252,9 +252,10 @@ def forward(self, imgs, size=640, augment=False, profile=False):
for i, im in enumerate(imgs):
f = f'image{i}' # filename
if isinstance(im, str): # filename or uri
im, f = np.asarray(Image.open(requests.get(im, stream=True).raw if im.startswith('http') else im)), im
im, f = Image.open(requests.get(im, stream=True).raw if im.startswith('http') else im), im
im = np.asarray(exif_transpose(im))
elif isinstance(im, Image.Image): # PIL Image
im, f = np.asarray(im), getattr(im, 'filename', f) or f
im, f = np.asarray(exif_transpose(im)), getattr(im, 'filename') or f
files.append(Path(f).with_suffix('.jpg').name)
if im.shape[0] < 5: # image in CHW
im = im.transpose((1, 2, 0)) # reverse dataloader .transpose(2, 0, 1)
Expand Down
26 changes: 26 additions & 0 deletions utils/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,32 @@ def exif_size(img):
return s


def exif_transpose(image):
"""
Transpose a PIL image accordingly if it has an EXIF Orientation tag.
From https://github.com/python-pillow/Pillow/blob/master/src/PIL/ImageOps.py
:param image: The image to transpose.
:return: An image.
"""
exif = image.getexif()
orientation = exif.get(0x0112, 1) # default 1
if orientation > 1:
method = {2: Image.FLIP_LEFT_RIGHT,
3: Image.ROTATE_180,
4: Image.FLIP_TOP_BOTTOM,
5: Image.TRANSPOSE,
6: Image.ROTATE_270,
7: Image.TRANSVERSE,
8: Image.ROTATE_90,
}.get(orientation)
if method is not None:
image = image.transpose(method)
del exif[0x0112]
image.info["exif"] = exif.tobytes()
return image


def create_dataloader(path, imgsz, batch_size, stride, single_cls=False, hyp=None, augment=False, cache=False, pad=0.0,
rect=False, rank=-1, workers=8, image_weights=False, quad=False, prefix=''):
# Make sure only the first process in DDP process the dataset first, and the following others can use the cache
Expand Down

0 comments on commit 831773f

Please sign in to comment.