From d8f18834a246cfe3589406635c7e990f8043130a Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Fri, 30 Jul 2021 18:17:19 +0200 Subject: [PATCH] Update `profile()` for CUDA Memory allocation (#4239) * Update profile() * Update profile() * Update profile() * Update profile() * Update profile() * Update profile() * Update profile() * Update profile() * Update profile() * Update profile() * Update profile() * Update profile() * Cleanup --- tutorial.ipynb | 4 +-- utils/torch_utils.py | 76 ++++++++++++++++++++++++++------------------ 2 files changed, 47 insertions(+), 33 deletions(-) diff --git a/tutorial.ipynb b/tutorial.ipynb index 8d9c3f8b7a15..b16506275288 100644 --- a/tutorial.ipynb +++ b/tutorial.ipynb @@ -1172,11 +1172,11 @@ }, "source": [ "# Profile\n", - "from utils.torch_utils import profile \n", + "from utils.torch_utils import profile\n", "\n", "m1 = lambda x: x * torch.sigmoid(x)\n", "m2 = torch.nn.SiLU()\n", - "profile(x=torch.randn(16, 3, 640, 640), ops=[m1, m2], n=100)" + "results = profile(input=torch.randn(16, 3, 640, 640), ops=[m1, m2], n=100)" ], "execution_count": null, "outputs": [] diff --git a/utils/torch_utils.py b/utils/torch_utils.py index 55a5fd7875bb..4956cf95d1ca 100644 --- a/utils/torch_utils.py +++ b/utils/torch_utils.py @@ -98,42 +98,56 @@ def time_sync(): return time.time() -def profile(x, ops, n=100, device=None): - # profile a pytorch module or list of modules. Example usage: - # x = torch.randn(16, 3, 640, 640) # input +def profile(input, ops, n=10, device=None): + # YOLOv5 speed/memory/FLOPs profiler + # + # Usage: + # input = torch.randn(16, 3, 640, 640) # m1 = lambda x: x * torch.sigmoid(x) # m2 = nn.SiLU() - # profile(x, [m1, m2], n=100) # profile speed over 100 iterations + # profile(input, [m1, m2], n=100) # profile over 100 iterations + results = [] device = device or select_device() - x = x.to(device) - x.requires_grad = True - print(f"{'Params':>12s}{'GFLOPs':>12s}{'forward (ms)':>16s}{'backward (ms)':>16s}{'input':>24s}{'output':>24s}") - for m in ops if isinstance(ops, list) else [ops]: - m = m.to(device) if hasattr(m, 'to') else m # device - m = m.half() if hasattr(m, 'half') and isinstance(x, torch.Tensor) and x.dtype is torch.float16 else m # type - dtf, dtb, t = 0., 0., [0., 0., 0.] # dt forward, backward - try: - flops = thop.profile(m, inputs=(x,), verbose=False)[0] / 1E9 * 2 # GFLOPs - except: - flops = 0 - - for _ in range(n): - t[0] = time_sync() - y = m(x) - t[1] = time_sync() + print(f"{'Params':>12s}{'GFLOPs':>12s}{'GPU_mem (GB)':>14s}{'forward (ms)':>14s}{'backward (ms)':>14s}" + f"{'input':>24s}{'output':>24s}") + + for x in input if isinstance(input, list) else [input]: + x = x.to(device) + x.requires_grad = True + for m in ops if isinstance(ops, list) else [ops]: + m = m.to(device) if hasattr(m, 'to') else m # device + m = m.half() if hasattr(m, 'half') and isinstance(x, torch.Tensor) and x.dtype is torch.float16 else m + tf, tb, t = 0., 0., [0., 0., 0.] # dt forward, backward try: - _ = y.sum().backward() - t[2] = time_sync() - except: # no backward method - t[2] = float('nan') - dtf += (t[1] - t[0]) * 1000 / n # ms per op forward - dtb += (t[2] - t[1]) * 1000 / n # ms per op backward - - s_in = tuple(x.shape) if isinstance(x, torch.Tensor) else 'list' - s_out = tuple(y.shape) if isinstance(y, torch.Tensor) else 'list' - p = sum(list(x.numel() for x in m.parameters())) if isinstance(m, nn.Module) else 0 # parameters - print(f'{p:12}{flops:12.4g}{dtf:16.4g}{dtb:16.4g}{str(s_in):>24s}{str(s_out):>24s}') + flops = thop.profile(m, inputs=(x,), verbose=False)[0] / 1E9 * 2 # GFLOPs + except: + flops = 0 + + try: + for _ in range(n): + t[0] = time_sync() + y = m(x) + t[1] = time_sync() + try: + _ = (sum([yi.sum() for yi in y]) if isinstance(y, list) else y).sum().backward() + t[2] = time_sync() + except Exception as e: # no backward method + print(e) + t[2] = float('nan') + tf += (t[1] - t[0]) * 1000 / n # ms per op forward + tb += (t[2] - t[1]) * 1000 / n # ms per op backward + mem = torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0 # (GB) + s_in = tuple(x.shape) if isinstance(x, torch.Tensor) else 'list' + s_out = tuple(y.shape) if isinstance(y, torch.Tensor) else 'list' + p = sum(list(x.numel() for x in m.parameters())) if isinstance(m, nn.Module) else 0 # parameters + print(f'{p:12}{flops:12.4g}{mem:>14.3f}{tf:14.4g}{tb:14.4g}{str(s_in):>24s}{str(s_out):>24s}') + results.append([p, flops, mem, tf, tb, s_in, s_out]) + except Exception as e: + print(e) + results.append(None) + torch.cuda.empty_cache() + return results def is_parallel(model):