Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiqwang committed Nov 24, 2021
1 parent 711a664 commit 07203d9
Show file tree
Hide file tree
Showing 3 changed files with 123 additions and 151 deletions.
3 changes: 2 additions & 1 deletion yolort/v5/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import torch

from .models import AutoShape
from .models.yolo import Model
from .utils import attempt_download, intersect_dicts, set_logging

Expand Down Expand Up @@ -68,6 +69,6 @@ def load_yolov5_model(checkpoint_path: str, autoshape: bool = False, verbose: bo
model.load_state_dict(ckpt_state_dict, strict=False)

if autoshape:
model = model.autoshape()
model = AutoShape(model)

return model
54 changes: 33 additions & 21 deletions yolort/v5/models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
xyxy2xywh,
)
from yolort.v5.utils.plots import Annotator, colors, save_one_box
from yolort.v5.utils.torch_utils import time_sync
from yolort.v5.utils.torch_utils import copy_attr, time_sync

LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -419,32 +419,44 @@ class AutoShape(nn.Module):

conf = 0.25 # NMS confidence threshold
iou = 0.45 # NMS IoU threshold
classes = None # (optional list) filter by class
# (optional list) filter by class, i.e. = [0, 15, 16] for COCO persons, cats and dogs
classes = None
multi_label = False # NMS multiple labels per box
max_det = 1000 # maximum number of detections per image

def __init__(self, model):
super().__init__()
LOGGER.info("Adding AutoShape... ")
# copy attributes
copy_attr(self, model, include=("yaml", "nc", "hyp", "names", "stride", "abc"), exclude=())
self.model = model.eval()

def autoshape(self):
# model already converted to model.autoshape()
LOGGER.info("AutoShape already enabled, skipping... ")
def _apply(self, fn):
"""
Apply to(), cpu(), cuda(), half() to model tensors that
are not parameters or registered buffers
"""
self = super()._apply(fn)
m = self.model.model[-1] # Detect()
m.stride = fn(m.stride)
m.grid = list(map(fn, m.grid))
if isinstance(m.anchor_grid, list):
m.anchor_grid = list(map(fn, m.anchor_grid))
return self

@torch.no_grad()
def forward(self, imgs, size=640, augment=False, profile=False):
"""
Inference from various sources. For height=640, width=1280, RGB images example inputs are:
- file: imgs = 'data/images/zidane.jpg' # str or PosixPath
- URI: = 'https://ultralytics.com/images/zidane.jpg'
- OpenCV: = cv2.imread('image.jpg')[:,:,::-1] # HWC BGR to RGB x(640,1280,3)
- PIL: = Image.open('image.jpg') or ImageGrab.grab() # HWC x(640,1280,3)
- numpy: = np.zeros((640,1280,3)) # HWC
- torch: = torch.zeros(16,3,320,640) # BCHW (scaled to size=640, 0-1 values)
- multiple: = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...] # list of images
"""
from yolort.v5.utils.augmentations import letterbox

# Inference from various sources. For height=640, width=1280, RGB images example inputs are:
# file: imgs = 'data/images/zidane.jpg' # str or PosixPath
# URI: = 'https://ultralytics.com/images/zidane.jpg'
# OpenCV: = cv2.imread('image.jpg')[:,:,::-1] # HWC BGR to RGB x(640,1280,3)
# PIL: = Image.open('image.jpg') or ImageGrab.grab() # HWC x(640,1280,3)
# numpy: = np.zeros((640,1280,3)) # HWC
# torch: = torch.zeros(16,3,320,640) # BCHW (scaled to size=640, 0-1 values)
# multiple: = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...] # list of images

from yolort.v5.utils.datasets import exif_transpose

t = [time_sync()]
Expand All @@ -454,10 +466,10 @@ def forward(self, imgs, size=640, augment=False, profile=False):
return self.model(imgs.to(p.device).type_as(p), augment, profile) # inference

# Pre-process
n, imgs = (
(len(imgs), imgs) if isinstance(imgs, list) else (1, [imgs])
) # number of images, list of images
shape0, shape1, files = [], [], [] # image and inference shapes, filenames
# number of images, list of images
n, imgs = (len(imgs), imgs) if isinstance(imgs, list) else (1, [imgs])
# image and inference shapes, filenames
shape0, shape1, files = [], [], []
for i, im in enumerate(imgs):
f = f"image{i}" # filename
if isinstance(im, (str, Path)): # filename or uri
Expand All @@ -482,7 +494,7 @@ def forward(self, imgs, size=640, augment=False, profile=False):
x = [letterbox(im, new_shape=shape1, auto=False)[0] for im in imgs] # pad
x = np.stack(x, 0) if n > 1 else x[0][None] # stack
x = np.ascontiguousarray(x.transpose((0, 3, 1, 2))) # BHWC to BCHW
x = torch.from_numpy(x).to(p.device).type_as(p) / 255.0 # uint8 to fp16/32
x = torch.from_numpy(x).to(p.device).type_as(p) / 255 # uint8 to fp16/32
t.append(time_sync())

with amp.autocast(enabled=p.device.type != "cpu"):
Expand All @@ -498,7 +510,7 @@ def forward(self, imgs, size=640, augment=False, profile=False):
classes=self.classes,
multi_label=self.multi_label,
max_det=self.max_det,
) # NMS
)
for i in range(n):
scale_coords(shape1, y[i][:, :4], shape0[i])

Expand Down
Loading

0 comments on commit 07203d9

Please sign in to comment.