GFLOPs computation fix for classification models (ultralytics#8954)
* GFLOPs computation fix for classification models

Improved robustness in reading input channel count

* Update torch_utils.py

* Update torch_utils.py
glenn-jocher authored and Clay Januhowski committed Sep 8, 2022
1 parent 24e08b9 commit 92c81dc
Showing 1 changed file with 8 additions and 9 deletions.
17 changes: 8 additions & 9 deletions utils/torch_utils.py
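Why the change helps: the old code took the profiling input's channel count from model.yaml.get('ch', 3), which assumes a YAML-built detection model. Classification models have no yaml attribute, so the AttributeError was swallowed by the try/except in model_info and no GFLOPs figure was reported at all. The new code reads the channel count from the shape of the model's first parameter instead. A minimal sketch of that idea, using a hypothetical plain nn.Sequential classifier as a stand-in for a model without a yaml attribute:

import torch
import torch.nn as nn

# Toy classifier: no .yaml and no .stride attribute, like a classification model
model = nn.Sequential(nn.Conv2d(3, 16, 3, 2, 1), nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(16, 10))

# Old approach: model.yaml.get('ch', 3) raises AttributeError on this model.
# New approach: Conv2d weights have shape (out_ch, in_ch, kH, kW), so the first
# parameter's shape[1] recovers the input channel count.
p = next(model.parameters())
ch = p.shape[1]  # 3
im = torch.zeros((1, ch, 32, 32), device=p.device)  # profiling input in BCHW format
print(ch, tuple(im.shape))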
@@ -199,12 +199,11 @@ def sparsity(model):
 def prune(model, amount=0.3):
     # Prune model to requested global sparsity
     import torch.nn.utils.prune as prune
-    print('Pruning model... ', end='')
     for name, m in model.named_modules():
         if isinstance(m, nn.Conv2d):
             prune.l1_unstructured(m, name='weight', amount=amount)  # prune
             prune.remove(m, 'weight')  # make permanent
-    print(' %.3g global sparsity' % sparsity(model))
+    LOGGER.info(f'Model pruned to {sparsity(model):.3g} global sparsity')
 
 
 def fuse_conv_and_bn(conv, bn):
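The prune() hunk above only swaps bare print() calls for the repo's LOGGER; the pruning logic itself is unchanged. For reference, a self-contained sketch of what that loop does, using a hypothetical two-layer toy model and computing inline the zero-weight fraction that the sparsity() helper reports:

import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.Conv2d(8, 8, 3))  # toy model

for name, m in model.named_modules():
    if isinstance(m, nn.Conv2d):
        prune.l1_unstructured(m, name='weight', amount=0.3)  # mask the 30% smallest-magnitude weights
        prune.remove(m, 'weight')  # make permanent: bake the mask into a plain, partly-zeroed tensor

zeros = sum((p == 0).sum().item() for p in model.parameters())
total = sum(p.numel() for p in model.parameters())
print(f'Model pruned to {zeros / total:.3g} global sparsity')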
@@ -230,7 +229,7 @@ def fuse_conv_and_bn(conv, bn):
     return fusedconv
 
 
-def model_info(model, verbose=False, img_size=640):
+def model_info(model, verbose=False, imgsz=640):
     # Model information. img_size may be int or list, i.e. img_size=640 or img_size=[640, 320]
     n_p = sum(x.numel() for x in model.parameters())  # number parameters
     n_g = sum(x.numel() for x in model.parameters() if x.requires_grad)  # number gradients
@@ -242,12 +241,12 @@ def model_info(model, verbose=False, img_size=640):
                   (i, name, p.requires_grad, p.numel(), list(p.shape), p.mean(), p.std()))
 
     try:  # FLOPs
-        from thop import profile
-        stride = max(int(model.stride.max()), 32) if hasattr(model, 'stride') else 32
-        img = torch.zeros((1, model.yaml.get('ch', 3), stride, stride), device=next(model.parameters()).device)  # input
-        flops = profile(deepcopy(model), inputs=(img,), verbose=False)[0] / 1E9 * 2  # stride GFLOPs
-        img_size = img_size if isinstance(img_size, list) else [img_size, img_size]  # expand if int/float
-        fs = ', %.1f GFLOPs' % (flops * img_size[0] / stride * img_size[1] / stride)  # 640x640 GFLOPs
+        p = next(model.parameters())
+        stride = max(int(model.stride.max()), 32) if hasattr(model, 'stride') else 32  # max stride
+        im = torch.zeros((1, p.shape[1], stride, stride), device=p.device)  # input image in BCHW format
+        flops = thop.profile(deepcopy(model), inputs=(im,), verbose=False)[0] / 1E9 * 2  # stride GFLOPs
+        imgsz = imgsz if isinstance(imgsz, list) else [imgsz, imgsz]  # expand if int/float
+        fs = f', {flops * imgsz[0] / stride * imgsz[1] / stride:.1f} GFLOPs'  # 640x640 GFLOPs
     except Exception:
         fs = ''
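Note the trick both versions share: thop profiles a single stride x stride input (32x32 for a strideless classifier) and the result is scaled by imgsz/stride in each spatial dimension, which is essentially exact for purely convolutional networks whose cost grows linearly with input area. A runnable sketch of that scaling, assuming the thop package is installed and using a hypothetical one-layer stand-in model:

from copy import deepcopy

import thop
import torch
import torch.nn as nn

model = nn.Conv2d(3, 16, 3, 1, 1)  # stand-in model
p = next(model.parameters())
stride, imgsz = 32, [640, 640]

im = torch.zeros((1, p.shape[1], stride, stride), device=p.device)  # input image in BCHW format
flops = thop.profile(deepcopy(model), inputs=(im,), verbose=False)[0] / 1E9 * 2  # giga-MACs * 2 = GFLOPs at 32x32
gflops = flops * imgsz[0] / stride * imgsz[1] / stride  # scale by (640/32) * (640/32) = 400x
print(f'{gflops:.1f} GFLOPs at {imgsz[0]}x{imgsz[1]}')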
