diff --git a/utils/general.py b/utils/general.py index a104eff..48f1a0c 100644 --- a/utils/general.py +++ b/utils/general.py @@ -1089,6 +1089,9 @@ def output_to_target(output, width, height): targets = [] for i, o in enumerate(output): if o is not None: + # sometimes output can be a list of tensor, so here ensure the type again, this fixes the error. + if isinstance(o, torch.Tensor): + o = o.cpu().numpy() for pred in o: box = pred[:4] w = (box[2] - box[0]) / width