diff --git a/test.py b/test.py index 0093b8c09f54..f8936d3b4f9d 100644 --- a/test.py +++ b/test.py @@ -188,8 +188,8 @@ def test(data, # Per target class for cls in torch.unique(tcls_tensor): - ti = (cls == tcls_tensor).nonzero(as_tuple=False).view(-1) # prediction indices - pi = (cls == pred[:, 5]).nonzero(as_tuple=False).view(-1) # target indices + ti = (cls == tcls_tensor).nonzero(as_tuple=False).view(-1) # target indices + pi = (cls == pred[:, 5]).nonzero(as_tuple=False).view(-1) # prediction indices # Search for detections if pi.shape[0]: