Skip to content

Commit

Permalink
Default PyTorch Hub to autocast(False) (#5926)
Browse files Browse the repository at this point in the history
  • Loading branch information
glenn-jocher authored Dec 8, 2021
1 parent c77a5a8 commit 5bdb28e
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,7 @@ class AutoShape(nn.Module):
multi_label = False # NMS multiple labels per box
classes = None # (optional list) filter by class, i.e. = [0, 15, 16] for COCO persons, cats and dogs
max_det = 1000 # maximum number of detections per image
amp = False # Automatic Mixed Precision (AMP) inference

def __init__(self, model):
super().__init__()
Expand Down Expand Up @@ -476,8 +477,9 @@ def forward(self, imgs, size=640, augment=False, profile=False):

t = [time_sync()]
p = next(self.model.parameters()) if self.pt else torch.zeros(1) # for device and type
autocast = self.amp and (p.device.type != 'cpu') # Automatic Mixed Precision (AMP) inference
if isinstance(imgs, torch.Tensor): # torch
with amp.autocast(enabled=p.device.type != 'cpu'):
with amp.autocast(enabled=autocast):
return self.model(imgs.to(p.device).type_as(p), augment, profile) # inference

# Pre-process
Expand Down Expand Up @@ -506,7 +508,7 @@ def forward(self, imgs, size=640, augment=False, profile=False):
x = torch.from_numpy(x).to(p.device).type_as(p) / 255 # uint8 to fp16/32
t.append(time_sync())

with amp.autocast(enabled=p.device.type != 'cpu'):
with amp.autocast(enabled=autocast):
# Inference
y = self.model(x, augment, profile) # forward
t.append(time_sync())
Expand Down

0 comments on commit 5bdb28e

Please sign in to comment.