diff --git a/models/common.py b/models/common.py index 2acf6281f475..56077c84acf9 100644 --- a/models/common.py +++ b/models/common.py @@ -289,6 +289,14 @@ def autoshape(self): LOGGER.info('AutoShape already enabled, skipping... ') # model already converted to model.autoshape() return self + def _apply(self, fn): + # Apply to(), cpu(), cuda(), half() to model tensors that are not parameters or registered buffers + self = super()._apply(fn) + m = self.model.model[-1] # Detect() + m.stride = fn(m.stride) + m.grid = list(map(fn, m.grid)) + return self + @torch.no_grad() def forward(self, imgs, size=640, augment=False, profile=False): # Inference from various sources. For height=640, width=1280, RGB images example inputs are: diff --git a/models/yolo.py b/models/yolo.py index 5087a9f7399d..74b266bff00d 100644 --- a/models/yolo.py +++ b/models/yolo.py @@ -232,6 +232,15 @@ def autoshape(self): # add AutoShape module def info(self, verbose=False, img_size=640): # print model information model_info(self, verbose, img_size) + def _apply(self, fn): + # Apply to(), cpu(), cuda(), half() to model tensors that are not parameters or registered buffers + self = super()._apply(fn) + m = self.model[-1] # Detect() + if isinstance(m, Detect): + m.stride = fn(m.stride) + m.grid = list(map(fn, m.grid)) + return self + def parse_model(d, ch): # model_dict, input_channels(3) LOGGER.info('\n%3s%18s%3s%10s %-40s%-30s' % ('', 'from', 'n', 'params', 'module', 'arguments'))