Skip to content

Commit

Permalink
FLOPS computation device bug fix (#1447)
Browse files Browse the repository at this point in the history
* Update torch_utils.py

fix issue#113 , inputs device should be same with model parameters' device

* Update torch_utils.py

* Update torch_utils.py

Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
  • Loading branch information
JunnYu and glenn-jocher authored Nov 19, 2020
1 parent af8aee7 commit 05a955a
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion utils/torch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,8 @@ def model_info(model, verbose=False, img_size=640):
try: # FLOPS
from thop import profile
stride = int(model.stride.max())
flops = profile(deepcopy(model), inputs=(torch.zeros(1, 3, stride, stride),), verbose=False)[0] / 1E9 * 2
img = torch.zeros((1, 3, stride, stride), device=next(model.parameters()).device) # input
flops = profile(deepcopy(model), inputs=(img,), verbose=False)[0] / 1E9 * 2 # stride FLOPS
img_size = img_size if isinstance(img_size, list) else [img_size, img_size] # expand if int/float
fs = ', %.1f GFLOPS' % (flops * img_size[0] / stride * img_size[1] / stride) # 640x640 FLOPS
except (ImportError, Exception):
Expand Down

0 comments on commit 05a955a

Please sign in to comment.