Skip to content

Commit

Permalink
Refactor new model.warmup() method (ultralytics#5810)
Browse files Browse the repository at this point in the history
* Refactor new `model.warmup()` method

* Add half
  • Loading branch information
glenn-jocher authored Nov 27, 2021
1 parent ca9ad37 commit 8414e75
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 4 deletions.
3 changes: 1 addition & 2 deletions detect.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,7 @@ def run(weights=ROOT / 'yolov5s.pt', # model.pt path(s)
vid_path, vid_writer = [None] * bs, [None] * bs

# Run inference
if pt and device.type != 'cpu':
model(torch.zeros(1, 3, *imgsz).to(device).type_as(next(model.model.parameters()))) # warmup
model.warmup(imgsz=(1, 3, *imgsz), half=half) # warmup
dt, seen = [0.0, 0.0, 0.0], 0
for path, im, im0s, vid_cap, s in dataset:
t1 = time_sync()
Expand Down
7 changes: 7 additions & 0 deletions models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,13 @@ def forward(self, im, augment=False, visualize=False, val=False):
y = torch.tensor(y) if isinstance(y, np.ndarray) else y
return (y, []) if val else y

def warmup(self, imgsz=(1, 3, 640, 640), half=False):
# Warmup model by running inference once
if self.pt or self.engine or self.onnx: # warmup types
if isinstance(self.device, torch.device) and self.device.type != 'cpu': # only warmup GPU models
im = torch.zeros(*imgsz).to(self.device).type(torch.half if half else torch.float) # input image
self.forward(im) # warmup


class AutoShape(nn.Module):
# YOLOv5 input-robust model wrapper for passing cv2/np/PIL/torch inputs. Includes preprocessing, inference and NMS
Expand Down
3 changes: 1 addition & 2 deletions val.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,8 +149,7 @@ def run(data,

# Dataloader
if not training:
if pt and device.type != 'cpu':
model(torch.zeros(1, 3, imgsz, imgsz).to(device).type_as(next(model.model.parameters()))) # warmup
model.warmup(imgsz=(1, 3, imgsz, imgsz), half=half) # warmup
pad = 0.0 if task == 'speed' else 0.5
task = task if task in ('train', 'val', 'test') else 'val' # path to train/val/test images
dataloader = create_dataloader(data[task], imgsz, batch_size, stride, single_cls, pad=pad, rect=pt,
Expand Down

0 comments on commit 8414e75

Please sign in to comment.