diff --git a/utils/loggers/__init__.py b/utils/loggers/__init__.py index c5cdd92772f2..b9869df26a43 100644 --- a/utils/loggers/__init__.py +++ b/utils/loggers/__init__.py @@ -296,7 +296,7 @@ def log_tensorboard_graph(tb, model, imgsz=(640, 640)): try: p = next(model.parameters()) # for device, type imgsz = (imgsz, imgsz) if isinstance(imgsz, int) else imgsz # expand - im = torch.empty((1, 3, *imgsz)).to(p.device).type_as(p) # input image + im = torch.zeros((1, 3, *imgsz)).to(p.device).type_as(p) # input image (WARNING: must be zeros, not empty) with warnings.catch_warnings(): warnings.simplefilter('ignore') # suppress jit trace warning tb.add_graph(torch.jit.trace(de_parallel(model), im, strict=False), [])