Skip to content

Commit

Permalink
Input channel yaml['ch'] addition (ultralytics#1741)
Browse files Browse the repository at this point in the history
  • Loading branch information
glenn-jocher committed Dec 19, 2020
1 parent cc5bd96 commit 411802c
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 3 deletions.
5 changes: 3 additions & 2 deletions models/yolo.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import argparse
import logging
import math
import sys
from copy import deepcopy
from pathlib import Path

import math
import torch
import torch.nn as nn

Expand Down Expand Up @@ -78,10 +78,11 @@ def __init__(self, cfg='yolov5s.yaml', ch=3, nc=None): # model, input channels,
self.yaml = yaml.load(f, Loader=yaml.FullLoader) # model dict

# Define model
ch = self.yaml['ch'] = self.yaml.get('ch', ch) # input channels
if nc and nc != self.yaml['nc']:
logger.info('Overriding model.yaml nc=%g with nc=%g' % (self.yaml['nc'], nc))
self.yaml['nc'] = nc # override yaml value
self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch]) # model, savelist, ch_out
self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch]) # model, savelist
self.names = [str(i) for i in range(self.yaml['nc'])] # default names
# print([x.shape for x in self.forward(torch.zeros(1, ch, 64, 64))])

Expand Down
2 changes: 1 addition & 1 deletion utils/torch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ def model_info(model, verbose=False, img_size=640):
try: # FLOPS
from thop import profile
stride = int(model.stride.max()) if hasattr(model, 'stride') else 32
img = torch.zeros((1, 3, stride, stride), device=next(model.parameters()).device) # input
img = torch.zeros((1, model.yaml.get('ch', 3), stride, stride), device=next(model.parameters()).device) # input
flops = profile(deepcopy(model), inputs=(img,), verbose=False)[0] / 1E9 * 2 # stride FLOPS
img_size = img_size if isinstance(img_size, list) else [img_size, img_size] # expand if int/float
fs = ', %.1f GFLOPS' % (flops * img_size[0] / stride * img_size[1] / stride) # 640x640 FLOPS
Expand Down

0 comments on commit 411802c

Please sign in to comment.