Skip to content

Commit

Permalink
Update torch_utils.py
Browse files Browse the repository at this point in the history
FLOPS to GFLOPS
  • Loading branch information
glenn-jocher committed Dec 21, 2020
1 parent 394d1c8 commit 0bd9c48
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions utils/torch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def profile(x, ops, n=100, device=None):
x = x.to(device)
x.requires_grad = True
print(torch.__version__, device.type, torch.cuda.get_device_properties(0) if device.type == 'cuda' else '')
print(f"\n{'Params':>12s}{'FLOPS':>12s}{'forward (ms)':>16s}{'backward (ms)':>16s}{'input':>24s}{'output':>24s}")
print(f"\n{'Params':>12s}{'GFLOPS':>12s}{'forward (ms)':>16s}{'backward (ms)':>16s}{'input':>24s}{'output':>24s}")
for m in ops if isinstance(ops, list) else [ops]:
m = m.to(device) if hasattr(m, 'to') else m
dtf, dtb, t = 0., 0., [0., 0., 0.] # dt forward, backward
Expand Down Expand Up @@ -197,9 +197,9 @@ def model_info(model, verbose=False, img_size=640):
from thop import profile
stride = int(model.stride.max()) if hasattr(model, 'stride') else 32
img = torch.zeros((1, model.yaml.get('ch', 3), stride, stride), device=next(model.parameters()).device) # input
flops = profile(deepcopy(model), inputs=(img,), verbose=False)[0] / 1E9 * 2 # stride FLOPS
flops = profile(deepcopy(model), inputs=(img,), verbose=False)[0] / 1E9 * 2 # stride GFLOPS
img_size = img_size if isinstance(img_size, list) else [img_size, img_size] # expand if int/float
fs = ', %.1f GFLOPS' % (flops * img_size[0] / stride * img_size[1] / stride) # 640x640 FLOPS
fs = ', %.1f GFLOPS' % (flops * img_size[0] / stride * img_size[1] / stride) # 640x640 GFLOPS
except (ImportError, Exception):
fs = ''

Expand Down

0 comments on commit 0bd9c48

Please sign in to comment.