Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding --save-conf to test.py and detect.py #1175

Closed
Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions detect.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@


def detect(save_img=False):
out, source, weights, view_img, save_txt, imgsz = \
opt.output, opt.source, opt.weights, opt.view_img, opt.save_txt, opt.img_size
out, source, weights, view_img, save_txt, imgsz, save_conf = \
opt.output, opt.source, opt.weights, opt.view_img, opt.save_txt, opt.img_size, opt.save_conf
webcam = source.isnumeric() or source.startswith(('rtsp://', 'rtmp://', 'http://')) or source.endswith('.txt')

# Initialize
Expand Down Expand Up @@ -105,7 +105,10 @@ def detect(save_img=False):
if save_txt: # Write to file
xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist() # normalized xywh
with open(txt_path + '.txt', 'a') as f:
f.write(('%g ' * 5 + '\n') % (cls, *xywh)) # label format
if save_conf:
f.write(('%g ' * 6 + '\n') % (cls, conf, *xywh)) # label format includes conf
else:
f.write(('%g ' * 5 + '\n') % (cls, *xywh)) # label format does not include conf

if save_img or view_img: # Add bbox to image
label = '%s %.2f' % (names[int(cls)], conf)
Expand Down Expand Up @@ -158,6 +161,7 @@ def detect(save_img=False):
parser.add_argument('--agnostic-nms', action='store_true', help='class-agnostic NMS')
parser.add_argument('--augment', action='store_true', help='augmented inference')
parser.add_argument('--update', action='store_true', help='update all models')
parser.add_argument('--save-conf', action='store_true', help='put confidence score next to class in label*.txt')
opt = parser.parse_args()
print(opt)

Expand Down
16 changes: 13 additions & 3 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def test(data,
dataloader=None,
save_dir=Path(''), # for saving images
save_txt=False, # for auto-labelling
save_conf=False,
plots=True):
# Initialize/load model and set device
training = model is not None
Expand All @@ -43,7 +44,7 @@ def test(data,
device = select_device(opt.device, batch_size=batch_size)
save_txt = opt.save_txt # save *.txt labels
if save_txt:
out = Path('inference/output')
out = save_dir / 'output'
if os.path.exists(out):
shutil.rmtree(out) # delete output folder
os.makedirs(out) # make new output folder
Expand Down Expand Up @@ -133,7 +134,10 @@ def test(data,
for *xyxy, conf, cls in x:
xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist() # normalized xywh
with open(str(out / Path(paths[si]).stem) + '.txt', 'a') as f:
f.write(('%g ' * 5 + '\n') % (cls, *xywh)) # label format
if save_conf:
f.write(('%g ' * 6 + '\n') % (cls, conf, *xywh)) # label format includes conf
else:
f.write(('%g ' * 5 + '\n') % (cls, *xywh)) # label format does not include conf

# Clip boxes to image bounds
clip_coords(pred, (height, width))
Expand Down Expand Up @@ -263,6 +267,8 @@ def test(data,
parser.add_argument('--augment', action='store_true', help='augmented inference')
parser.add_argument('--verbose', action='store_true', help='report mAP by class')
parser.add_argument('--save-txt', action='store_true', help='save results to *.txt')
parser.add_argument('--save-conf', action='store_true', help='put confidence score next to class in label*.txt')
parser.add_argument('--output', type=str, default='', help='output folder') # output folder
opt = parser.parse_args()
opt.save_json |= opt.data.endswith('coco.yaml')
opt.data = check_file(opt.data) # check file
Expand All @@ -278,7 +284,11 @@ def test(data,
opt.save_json,
opt.single_cls,
opt.augment,
opt.verbose)
opt.verbose,
save_dir=Path(opt.output),
save_txt=opt.save_txt,
save_conf=opt.save_conf,
)

elif opt.task == 'study': # run over a range of settings and save/plot
for weights in ['yolov5s.pt', 'yolov5m.pt', 'yolov5l.pt', 'yolov5x.pt']:
Expand Down