From 394d1c89f33f29b2039c54f4a81e561a0b140b3a Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Sat, 19 Dec 2020 10:54:01 -0800 Subject: [PATCH] Input channel yaml['ch'] addition (#1741) --- models/yolo.py | 5 +++-- utils/torch_utils.py | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/models/yolo.py b/models/yolo.py index 4ad44afe5367..695ef761bd3d 100644 --- a/models/yolo.py +++ b/models/yolo.py @@ -1,10 +1,10 @@ import argparse import logging +import math import sys from copy import deepcopy from pathlib import Path -import math import torch import torch.nn as nn @@ -78,10 +78,11 @@ def __init__(self, cfg='yolov5s.yaml', ch=3, nc=None): # model, input channels, self.yaml = yaml.load(f, Loader=yaml.FullLoader) # model dict # Define model + ch = self.yaml['ch'] = self.yaml.get('ch', ch) # input channels if nc and nc != self.yaml['nc']: logger.info('Overriding model.yaml nc=%g with nc=%g' % (self.yaml['nc'], nc)) self.yaml['nc'] = nc # override yaml value - self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch]) # model, savelist, ch_out + self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch]) # model, savelist self.names = [str(i) for i in range(self.yaml['nc'])] # default names # print([x.shape for x in self.forward(torch.zeros(1, ch, 64, 64))]) diff --git a/utils/torch_utils.py b/utils/torch_utils.py index 64efebedd8b4..463837745585 100644 --- a/utils/torch_utils.py +++ b/utils/torch_utils.py @@ -196,7 +196,7 @@ def model_info(model, verbose=False, img_size=640): try: # FLOPS from thop import profile stride = int(model.stride.max()) if hasattr(model, 'stride') else 32 - img = torch.zeros((1, 3, stride, stride), device=next(model.parameters()).device) # input + img = torch.zeros((1, model.yaml.get('ch', 3), stride, stride), device=next(model.parameters()).device) # input flops = profile(deepcopy(model), inputs=(img,), verbose=False)[0] / 1E9 * 2 # stride FLOPS img_size = img_size if isinstance(img_size, list) else [img_size, img_size] # expand if int/float fs = ', %.1f GFLOPS' % (flops * img_size[0] / stride * img_size[1] / stride) # 640x640 FLOPS