Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding --save-conf to test.py and detect.py #1175

Closed
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def test(data,
dataloader=None,
save_dir=Path(''), # for saving images
save_txt=False, # for auto-labelling
save_conf=False,
plots=True):
# Initialize/load model and set device
training = model is not None
Expand All @@ -43,7 +44,7 @@ def test(data,
device = select_device(opt.device, batch_size=batch_size)
save_txt = opt.save_txt # save *.txt labels
if save_txt:
out = Path('inference/output')
out = save_dir / 'output'
if os.path.exists(out):
shutil.rmtree(out) # delete output folder
os.makedirs(out) # make new output folder
Expand Down Expand Up @@ -132,8 +133,9 @@ def test(data,
x[:, :4] = scale_coords(img[si].shape[1:], x[:, :4], shapes[si][0], shapes[si][1]) # to original
for *xyxy, conf, cls in x:
xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist() # normalized xywh
line = (cls, conf, *xywh) if save_conf else (cls, *xywh) # label format
with open(str(out / Path(paths[si]).stem) + '.txt', 'a') as f:
f.write(('%g ' * 5 + '\n') % (cls, *xywh)) # label format
f.write(('%g ' * len(line) + '\n') % line)

# Clip boxes to image bounds
clip_coords(pred, (height, width))
Expand Down Expand Up @@ -263,6 +265,8 @@ def test(data,
parser.add_argument('--augment', action='store_true', help='augmented inference')
parser.add_argument('--verbose', action='store_true', help='report mAP by class')
parser.add_argument('--save-txt', action='store_true', help='save results to *.txt')
parser.add_argument('--save-conf', action='store_true', help='put confidence score next to class in label*.txt')
parser.add_argument('--output', type=str, default='', help='output folder') # output folder
opt = parser.parse_args()
opt.save_json |= opt.data.endswith('coco.yaml')
opt.data = check_file(opt.data) # check file
Expand All @@ -278,7 +282,11 @@ def test(data,
opt.save_json,
opt.single_cls,
opt.augment,
opt.verbose)
opt.verbose,
save_dir=Path(opt.output),
save_txt=opt.save_txt,
save_conf=opt.save_conf,
)

elif opt.task == 'study': # run over a range of settings and save/plot
for weights in ['yolov5s.pt', 'yolov5m.pt', 'yolov5l.pt', 'yolov5x.pt']:
Expand Down