Skip to content

Commit

Permalink
fix yolo-nas int bug: confs always zero
Browse files Browse the repository at this point in the history
  • Loading branch information
mikel.brostrom committed May 27, 2023
1 parent 3522411 commit b292ef6
Showing 1 changed file with 3 additions and 6 deletions.
9 changes: 3 additions & 6 deletions examples/yolo_nas_track.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,8 @@
import super_gradients # for linear_assignment
except (ImportError, AssertionError, AttributeError):
from ultralytics.yolo.utils.checks import check_requirements

check_requirements('super_gradients') # install
import lap
import super_gradients


def on_predict_start(predictor):
Expand Down Expand Up @@ -128,7 +127,8 @@ def run(args):
prediction.labels[:, np.newaxis]
], axis=1
)
preds = torch.from_numpy(preds).int()
preds = torch.from_numpy(preds)
preds[:, 0:4] = preds[:, 0:4].int()
predictor.results = [None]
# # Postprocess
with predictor.profilers[2]:
Expand All @@ -149,14 +149,11 @@ def run(args):
dets = predictor.results[i].boxes.data
# get tracker predictions
predictor.tracker_outputs[i] = predictor.trackers[i].update(dets, im0)
print(predictor.tracker_outputs[i].shape)
print(predictor.tracker_outputs[i])
predictor.results[i].speed = {
'preprocess': predictor.profilers[0].dt * 1E3 / n,
'inference': predictor.profilers[1].dt * 1E3 / n,
'postprocess': predictor.profilers[2].dt * 1E3 / n,
'tracking': predictor.profilers[3].dt * 1E3 / n

}

# overwrite bbox results with tracker predictions
Expand Down

0 comments on commit b292ef6

Please sign in to comment.