diff --git a/export.py b/export.py index 437616a9890d..21c83c697b4d 100644 --- a/export.py +++ b/export.py @@ -411,7 +411,7 @@ def parse_opt(): parser.add_argument('--int8', action='store_true', help='CoreML/TF INT8 quantization') parser.add_argument('--dynamic', action='store_true', help='ONNX/TF: dynamic axes') parser.add_argument('--simplify', action='store_true', help='ONNX: simplify model') - parser.add_argument('--opset', type=int, default=13, help='ONNX: opset version') + parser.add_argument('--opset', type=int, default=14, help='ONNX: opset version') parser.add_argument('--verbose', action='store_true', help='TensorRT: verbose log') parser.add_argument('--workspace', type=int, default=4, help='TensorRT: workspace size (GB)') parser.add_argument('--topk-per-class', type=int, default=100, help='TF.js NMS: topk per class to keep') diff --git a/hubconf.py b/hubconf.py index 03335f7906f0..e407677b3233 100644 --- a/hubconf.py +++ b/hubconf.py @@ -5,6 +5,7 @@ Usage: import torch model = torch.hub.load('ultralytics/yolov5', 'yolov5s') + model = torch.hub.load('ultralytics/yolov5:master', 'custom', 'path/to/yolov5s.onnx') # file from branch """ import torch @@ -27,26 +28,25 @@ def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbo """ from pathlib import Path - from models.common import AutoShape - from models.experimental import attempt_load + from models.common import AutoShape, DetectMultiBackend from models.yolo import Model from utils.downloads import attempt_download from utils.general import check_requirements, intersect_dicts, set_logging from utils.torch_utils import select_device - file = Path(__file__).resolve() check_requirements(exclude=('tensorboard', 'thop', 'opencv-python')) set_logging(verbose=verbose) - save_dir = Path('') if str(name).endswith('.pt') else file.parent - path = (save_dir / name).with_suffix('.pt') # checkpoint path + name = Path(name) + path = name.with_suffix('.pt') if name.suffix == '' else name # checkpoint path try: device = select_device(('0' if torch.cuda.is_available() else 'cpu') if device is None else device) if pretrained and channels == 3 and classes == 80: - model = attempt_load(path, map_location=device) # download/load FP32 model + model = DetectMultiBackend(path, device=device) # download/load FP32 model + # model = models.experimental.attempt_load(path, map_location=device) # download/load FP32 model else: - cfg = list((Path(__file__).parent / 'models').rglob(f'{name}.yaml'))[0] # model.yaml path + cfg = list((Path(__file__).parent / 'models').rglob(f'{path.name}.yaml'))[0] # model.yaml path model = Model(cfg, channels, classes) # create model if pretrained: ckpt = torch.load(attempt_download(path), map_location=device) # load diff --git a/models/common.py b/models/common.py index 73f21729fa85..6a5303ba8c42 100644 --- a/models/common.py +++ b/models/common.py @@ -276,7 +276,7 @@ def forward(self, x): class DetectMultiBackend(nn.Module): # YOLOv5 MultiBackend class for python inference on various backends - def __init__(self, weights='yolov5s.pt', device=None, dnn=True): + def __init__(self, weights='yolov5s.pt', device=None, dnn=False): # Usage: # PyTorch: weights = *.pt # TorchScript: *.torchscript @@ -287,6 +287,8 @@ def __init__(self, weights='yolov5s.pt', device=None, dnn=True): # ONNX Runtime: *.onnx # OpenCV DNN: *.onnx with dnn=True # TensorRT: *.engine + from models.experimental import attempt_download, attempt_load # scoped to avoid circular import + super().__init__() w = str(weights[0] if isinstance(weights, list) else weights) suffix = Path(w).suffix.lower() @@ -294,6 +296,7 @@ def __init__(self, weights='yolov5s.pt', device=None, dnn=True): check_suffix(w, suffixes) # check weights have acceptable suffix pt, jit, onnx, engine, tflite, pb, saved_model, coreml = (suffix == x for x in suffixes) # backend booleans stride, names = 64, [f'class{i}' for i in range(1000)] # assign defaults + attempt_download(w) # download if not local if jit: # TorchScript LOGGER.info(f'Loading {w} for TorchScript inference...') @@ -303,11 +306,12 @@ def __init__(self, weights='yolov5s.pt', device=None, dnn=True): d = json.loads(extra_files['config.txt']) # extra_files dict stride, names = int(d['stride']), d['names'] elif pt: # PyTorch - from models.experimental import attempt_load # scoped to avoid circular import model = attempt_load(weights, map_location=device) stride = int(model.stride.max()) # model stride names = model.module.names if hasattr(model, 'module') else model.names # get class names + self.model = model # explicitly assign for to(), cpu(), cuda(), half() elif coreml: # CoreML + LOGGER.info(f'Loading {w} for CoreML inference...') import coremltools as ct model = ct.models.MLModel(w) elif dnn: # ONNX OpenCV DNN @@ -316,7 +320,7 @@ def __init__(self, weights='yolov5s.pt', device=None, dnn=True): net = cv2.dnn.readNetFromONNX(w) elif onnx: # ONNX Runtime LOGGER.info(f'Loading {w} for ONNX Runtime inference...') - check_requirements(('onnx', 'onnxruntime-gpu' if torch.has_cuda else 'onnxruntime')) + check_requirements(('onnx', 'onnxruntime-gpu' if torch.cuda.is_available() else 'onnxruntime')) import onnxruntime session = onnxruntime.InferenceSession(w, None) elif engine: # TensorRT @@ -376,7 +380,7 @@ def forward(self, im, augment=False, visualize=False, val=False): if self.pt: # PyTorch y = self.model(im) if self.jit else self.model(im, augment=augment, visualize=visualize) return y if val else y[0] - elif self.coreml: # CoreML *.mlmodel + elif self.coreml: # CoreML im = im.permute(0, 2, 3, 1).cpu().numpy() # torch BCHW to numpy BHWC shape(1,320,192,3) im = Image.fromarray((im[0] * 255).astype('uint8')) # im = im.resize((192, 320), Image.ANTIALIAS) @@ -433,24 +437,28 @@ class AutoShape(nn.Module): # YOLOv5 input-robust model wrapper for passing cv2/np/PIL/torch inputs. Includes preprocessing, inference and NMS conf = 0.25 # NMS confidence threshold iou = 0.45 # NMS IoU threshold - classes = None # (optional list) filter by class, i.e. = [0, 15, 16] for COCO persons, cats and dogs + agnostic = False # NMS class-agnostic multi_label = False # NMS multiple labels per box + classes = None # (optional list) filter by class, i.e. = [0, 15, 16] for COCO persons, cats and dogs max_det = 1000 # maximum number of detections per image def __init__(self, model): super().__init__() LOGGER.info('Adding AutoShape... ') copy_attr(self, model, include=('yaml', 'nc', 'hyp', 'names', 'stride', 'abc'), exclude=()) # copy attributes + self.dmb = isinstance(model, DetectMultiBackend) # DetectMultiBackend() instance + self.pt = not self.dmb or model.pt # PyTorch model self.model = model.eval() def _apply(self, fn): # Apply to(), cpu(), cuda(), half() to model tensors that are not parameters or registered buffers self = super()._apply(fn) - m = self.model.model[-1] # Detect() - m.stride = fn(m.stride) - m.grid = list(map(fn, m.grid)) - if isinstance(m.anchor_grid, list): - m.anchor_grid = list(map(fn, m.anchor_grid)) + if self.pt: + m = self.model.model.model[-1] if self.dmb else self.model.model[-1] # Detect() + m.stride = fn(m.stride) + m.grid = list(map(fn, m.grid)) + if isinstance(m.anchor_grid, list): + m.anchor_grid = list(map(fn, m.anchor_grid)) return self @torch.no_grad() @@ -465,7 +473,7 @@ def forward(self, imgs, size=640, augment=False, profile=False): # multiple: = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...] # list of images t = [time_sync()] - p = next(self.model.parameters()) # for device and type + p = next(self.model.parameters()) if self.pt else torch.zeros(1) # for device and type if isinstance(imgs, torch.Tensor): # torch with amp.autocast(enabled=p.device.type != 'cpu'): return self.model(imgs.to(p.device).type_as(p), augment, profile) # inference @@ -489,8 +497,8 @@ def forward(self, imgs, size=640, augment=False, profile=False): g = (size / max(s)) # gain shape1.append([y * g for y in s]) imgs[i] = im if im.data.contiguous else np.ascontiguousarray(im) # update - shape1 = [make_divisible(x, int(self.stride.max())) for x in np.stack(shape1, 0).max(0)] # inference shape - x = [letterbox(im, new_shape=shape1, auto=False)[0] for im in imgs] # pad + shape1 = [make_divisible(x, self.stride) for x in np.stack(shape1, 0).max(0)] # inference shape + x = [letterbox(im, new_shape=shape1 if self.pt else size, auto=False)[0] for im in imgs] # pad x = np.stack(x, 0) if n > 1 else x[0][None] # stack x = np.ascontiguousarray(x.transpose((0, 3, 1, 2))) # BHWC to BCHW x = torch.from_numpy(x).to(p.device).type_as(p) / 255 # uint8 to fp16/32 @@ -498,12 +506,12 @@ def forward(self, imgs, size=640, augment=False, profile=False): with amp.autocast(enabled=p.device.type != 'cpu'): # Inference - y = self.model(x, augment, profile)[0] # forward + y = self.model(x, augment, profile) # forward t.append(time_sync()) # Post-process - y = non_max_suppression(y, self.conf, iou_thres=self.iou, classes=self.classes, - multi_label=self.multi_label, max_det=self.max_det) # NMS + y = non_max_suppression(y if self.dmb else y[0], self.conf, iou_thres=self.iou, classes=self.classes, + agnostic=self.agnostic, multi_label=self.multi_label, max_det=self.max_det) # NMS for i in range(n): scale_coords(shape1, y[i][:, :4], shape0[i]) diff --git a/utils/general.py b/utils/general.py index 8aa76fbdb6ad..bbb9054a7235 100755 --- a/utils/general.py +++ b/utils/general.py @@ -455,7 +455,9 @@ def download_one(url, dir): def make_divisible(x, divisor): - # Returns x evenly divisible by divisor + # Returns nearest x divisible by divisor + if isinstance(divisor, torch.Tensor): + divisor = int(divisor.max()) # to int return math.ceil(x / divisor) * divisor