Skip to content

Commit

Permalink
Link fuse() to AutoShape() for Hub models (ultralytics#8599)
Browse files Browse the repository at this point in the history
  • Loading branch information
glenn-jocher authored and Clay Januhowski committed Sep 8, 2022
1 parent 895ffba commit 500efb5
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 4 deletions.
3 changes: 1 addition & 2 deletions hubconf.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,14 @@ def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbo

if not verbose:
LOGGER.setLevel(logging.WARNING)

check_requirements(exclude=('tensorboard', 'thop', 'opencv-python'))
name = Path(name)
path = name.with_suffix('.pt') if name.suffix == '' and not name.is_dir() else name # checkpoint path
try:
device = select_device(device)

if pretrained and channels == 3 and classes == 80:
model = DetectMultiBackend(path, device=device) # download/load FP32 model
model = DetectMultiBackend(path, device=device, fuse=autoshape) # download/load FP32 model
# model = models.experimental.attempt_load(path, map_location=device) # download/load FP32 model
else:
cfg = list((Path(__file__).parent / 'models').rglob(f'{path.stem}.yaml'))[0] # model.yaml path
Expand Down
4 changes: 2 additions & 2 deletions models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ def forward(self, x):

class DetectMultiBackend(nn.Module):
# YOLOv5 MultiBackend class for python inference on various backends
def __init__(self, weights='yolov5s.pt', device=torch.device('cpu'), dnn=False, data=None, fp16=False):
def __init__(self, weights='yolov5s.pt', device=torch.device('cpu'), dnn=False, data=None, fp16=False, fuse=True):
# Usage:
# PyTorch: weights = *.pt
# TorchScript: *.torchscript
Expand All @@ -331,7 +331,7 @@ def __init__(self, weights='yolov5s.pt', device=torch.device('cpu'), dnn=False,
names = yaml.safe_load(f)['names']

if pt: # PyTorch
model = attempt_load(weights if isinstance(weights, list) else w, device=device)
model = attempt_load(weights if isinstance(weights, list) else w, device=device, inplace=True, fuse=fuse)
stride = max(int(model.stride.max()), 32) # model stride
names = model.module.names if hasattr(model, 'module') else model.names # get class names
model.half() if fp16 else model.float()
Expand Down

0 comments on commit 500efb5

Please sign in to comment.