From 3004fb5bc17b5eb3f864d38a6fc9dbc18ec24109 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Mon, 21 Dec 2020 15:20:33 -0800 Subject: [PATCH] Automatic m.half() profile on x.half() --- utils/torch_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/utils/torch_utils.py b/utils/torch_utils.py index c76a318769b8..754b0870cd32 100644 --- a/utils/torch_utils.py +++ b/utils/torch_utils.py @@ -88,7 +88,8 @@ def profile(x, ops, n=100, device=None): print(torch.__version__, device.type, torch.cuda.get_device_properties(0) if device.type == 'cuda' else '') print(f"\n{'Params':>12s}{'GFLOPS':>12s}{'forward (ms)':>16s}{'backward (ms)':>16s}{'input':>24s}{'output':>24s}") for m in ops if isinstance(ops, list) else [ops]: - m = m.to(device) if hasattr(m, 'to') else m + m = m.to(device) if hasattr(m, 'to') else m # device + m = m.half() if hasattr(m, 'half') and isinstance(x, torch.Tensor) and x.dtype is torch.float16 else m # type dtf, dtb, t = 0., 0., [0., 0., 0.] # dt forward, backward try: flops = thop.profile(m, inputs=(x,), verbose=False)[0] / 1E9 * 2 # GFLOPS