From dc80adf3b0c50d54ca38b26a0b7737101555b480 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Sun, 21 Aug 2022 15:25:55 +0200 Subject: [PATCH] Revert "`torch.empty()` for speed improvements (#9025)" This reverts commit 61adf017f231f470afca2636f1f13e4cce13914b. --- models/common.py | 4 ++-- models/yolo.py | 6 +++--- utils/autobatch.py | 2 +- utils/loggers/__init__.py | 2 +- utils/torch_utils.py | 2 +- 5 files changed, 8 insertions(+), 8 deletions(-) diff --git a/models/common.py b/models/common.py index 44192e622bb5..33aa2ac12465 100644 --- a/models/common.py +++ b/models/common.py @@ -531,7 +531,7 @@ def warmup(self, imgsz=(1, 3, 640, 640)): # Warmup model by running inference once warmup_types = self.pt, self.jit, self.onnx, self.engine, self.saved_model, self.pb if any(warmup_types) and self.device.type != 'cpu': - im = torch.empty(*imgsz, dtype=torch.half if self.fp16 else torch.float, device=self.device) # input + im = torch.zeros(*imgsz, dtype=torch.half if self.fp16 else torch.float, device=self.device) # input for _ in range(2 if self.jit else 1): # self.forward(im) # warmup @@ -600,7 +600,7 @@ def forward(self, ims, size=640, augment=False, profile=False): dt = (Profile(), Profile(), Profile()) with dt[0]: - p = next(self.model.parameters()) if self.pt else torch.empty(1, device=self.model.device) # param + p = next(self.model.parameters()) if self.pt else torch.zeros(1, device=self.model.device) # param autocast = self.amp and (p.device.type != 'cpu') # Automatic Mixed Precision (AMP) inference if isinstance(ims, torch.Tensor): # torch with amp.autocast(autocast): diff --git a/models/yolo.py b/models/yolo.py index 32a47e9591da..df4209726e0d 100644 --- a/models/yolo.py +++ b/models/yolo.py @@ -46,8 +46,8 @@ def __init__(self, nc=80, anchors=(), ch=(), inplace=True): # detection layer self.no = nc + 5 # number of outputs per anchor self.nl = len(anchors) # number of detection layers self.na = len(anchors[0]) // 2 # number of anchors - self.grid = [torch.empty(1)] * self.nl # init grid - self.anchor_grid = [torch.empty(1)] * self.nl # init anchor grid + self.grid = [torch.zeros(1)] * self.nl # init grid + self.anchor_grid = [torch.zeros(1)] * self.nl # init anchor grid self.register_buffer('anchors', torch.tensor(anchors).float().view(self.nl, -1, 2)) # shape(nl,na,2) self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch) # output conv self.inplace = inplace # use inplace ops (e.g. slice assignment) @@ -175,7 +175,7 @@ def __init__(self, cfg='yolov5s.yaml', ch=3, nc=None, anchors=None): # model, i if isinstance(m, Detect): s = 256 # 2x min stride m.inplace = self.inplace - m.stride = torch.tensor([s / x.shape[-2] for x in self.forward(torch.empty(1, ch, s, s))]) # forward + m.stride = torch.tensor([s / x.shape[-2] for x in self.forward(torch.zeros(1, ch, s, s))]) # forward check_anchor_order(m) # must be in pixel-space (not grid-space) m.anchors /= m.stride.view(-1, 1, 1) self.stride = m.stride diff --git a/utils/autobatch.py b/utils/autobatch.py index 8d12e46f0f09..abae0203f7f4 100644 --- a/utils/autobatch.py +++ b/utils/autobatch.py @@ -47,7 +47,7 @@ def autobatch(model, imgsz=640, fraction=0.9, batch_size=16): # Profile batch sizes batch_sizes = [1, 2, 4, 8, 16] try: - img = [torch.empty(b, 3, imgsz, imgsz) for b in batch_sizes] + img = [torch.zeros(b, 3, imgsz, imgsz) for b in batch_sizes] results = profile(img, model, n=3, device=device) except Exception as e: LOGGER.warning(f'{prefix}{e}') diff --git a/utils/loggers/__init__.py b/utils/loggers/__init__.py index c5cdd92772f2..619617170d20 100644 --- a/utils/loggers/__init__.py +++ b/utils/loggers/__init__.py @@ -296,7 +296,7 @@ def log_tensorboard_graph(tb, model, imgsz=(640, 640)): try: p = next(model.parameters()) # for device, type imgsz = (imgsz, imgsz) if isinstance(imgsz, int) else imgsz # expand - im = torch.empty((1, 3, *imgsz)).to(p.device).type_as(p) # input image + im = torch.zeros((1, 3, *imgsz)).to(p.device).type_as(p) # input image with warnings.catch_warnings(): warnings.simplefilter('ignore') # suppress jit trace warning tb.add_graph(torch.jit.trace(de_parallel(model), im, strict=False), []) diff --git a/utils/torch_utils.py b/utils/torch_utils.py index 5fbe8bbf10f6..c1889f9b1dbb 100644 --- a/utils/torch_utils.py +++ b/utils/torch_utils.py @@ -281,7 +281,7 @@ def model_info(model, verbose=False, imgsz=640): try: # FLOPs p = next(model.parameters()) stride = max(int(model.stride.max()), 32) if hasattr(model, 'stride') else 32 # max stride - im = torch.empty((1, p.shape[1], stride, stride), device=p.device) # input image in BCHW format + im = torch.zeros((1, p.shape[1], stride, stride), device=p.device) # input image in BCHW format flops = thop.profile(deepcopy(model), inputs=(im,), verbose=False)[0] / 1E9 * 2 # stride GFLOPs imgsz = imgsz if isinstance(imgsz, list) else [imgsz, imgsz] # expand if int/float fs = f', {flops * imgsz[0] / stride * imgsz[1] / stride:.1f} GFLOPs' # 640x640 GFLOPs