Add support for W&B bounding box debugging and metric logging #1108

Closed
wants to merge 9 commits into from
test.py (82 changes: 62 additions & 20 deletions)
@@ -32,7 +32,15 @@ def test(data,
dataloader=None,
save_dir=Path(''), # for saving images
save_txt=False, # for auto-labelling
plots=True):
save_conf=False,
plots=True,
num_predictions=0):
# Import wandb if logging is enabled
if num_predictions > 0:
import wandb
if num_predictions > 100:
num_predictions = 100
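# values above 100 are silently clamped, presumably to bound W&B upload size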

# Initialize/load model and set device
training = model is not None
if training: # called by train.py
@@ -42,15 +50,17 @@ def test(data,
set_logging()
device = select_device(opt.device, batch_size=batch_size)
save_txt = opt.save_txt # save *.txt labels
if save_txt:
out = Path('inference/output')
if os.path.exists(out):
shutil.rmtree(out) # delete output folder
os.makedirs(out) # make new output folder

# Remove previous
for f in glob.glob(str(save_dir / 'test_batch*.jpg')):
os.remove(f)
if os.path.exists(save_dir):
shutil.rmtree(save_dir) # delete dir
os.makedirs(save_dir) # make new dir
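# save_dir is wiped and recreated, so each run starts with a clean results directory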

if save_txt:
out = save_dir / 'autolabels'
if os.path.exists(out):
shutil.rmtree(out) # delete dir
os.makedirs(out) # make new dir
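# auto-generated *.txt labels now live in a dedicated 'autolabels' subfolder under save_dir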

# Load model
model = attempt_load(weights, map_location=device) # load FP32 model
@@ -88,7 +98,7 @@ def test(data,
s = ('%20s' + '%12s' * 6) % ('Class', 'Images', 'Targets', 'P', 'R', 'mAP@.5', 'mAP@.5:.95')
p, r, f1, mp, mr, map50, map, t0, t1 = 0., 0., 0., 0., 0., 0., 0., 0., 0.
loss = torch.zeros(3, device=device)
jdict, stats, ap, ap_class = [], [], [], []
jdict, stats, ap, ap_class, wandb_image_log = [], [], [], [], []
for batch_i, (img, targets, paths, shapes) in enumerate(tqdm(dataloader, desc=s)):
img = img.to(device, non_blocking=True)
img = img.half() if half else img.float() # uint8 to fp16/32
@@ -106,7 +116,7 @@ def test(data,

# Compute loss
if training: # if model has loss hyperparameters
loss += compute_loss([x.float() for x in train_out], targets, model)[1][:3] # GIoU, obj, cls
loss += compute_loss([x.float() for x in train_out], targets, model)[1][:3] # box, obj, cls

# Run NMS
t = time_synchronized()
@@ -132,8 +142,28 @@ def test(data,
x[:, :4] = scale_coords(img[si].shape[1:], x[:, :4], shapes[si][0], shapes[si][1]) # to original
for *xyxy, conf, cls in x:
xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist() # normalized xywh
line = (cls, conf, *xywh) if save_conf else (cls, *xywh) # label format
with open(str(out / Path(paths[si]).stem) + '.txt', 'a') as f:
f.write(('%g ' * 5 + '\n') % (cls, *xywh)) # label format
f.write(('%g ' * len(line) + '\n') % line)
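# each detection is written as one line: class [conf] x y w h (xywh normalized by image size)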

# Log images with bounding boxes
if len(wandb_image_log) < num_predictions:
x = pred.clone()
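# clone so building the log does not mutate pred, which is still used below (e.g. by clip_coords)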
bbox_data = [{
"position": {
"minX": float(xyxy[0]),
"minY": float(xyxy[1]),
"maxX": float(xyxy[2]),
"maxY": float(xyxy[3])
},
"class_id": int(cls),
"scores": {
"class_score": float(conf)
},
"domain":"pixel"
} for *xyxy, conf, cls in x]
im = wandb.Image(img[si], boxes={"predictions": {"box_data":bbox_data}})
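# each box_data entry uses W&B's schema: pixel-space corners, class id, and per-box score;
# an optional "class_labels" {id: name} dict could sit next to "box_data" for readable labels (not used in this PR)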
wandb_image_log.append(im)

# Clip boxes to image bounds
clip_coords(pred, (height, width))
@@ -187,11 +217,15 @@ def test(data,

# Plot images
if plots and batch_i < 1:
f = save_dir / ('test_batch%g_gt.jpg' % batch_i) # filename
f = save_dir / f'test_batch{batch_i}_gt.jpg' # filename
plot_images(img, targets, paths, str(f), names) # ground truth
f = save_dir / ('test_batch%g_pred.jpg' % batch_i)
f = save_dir / f'test_batch{batch_i}_pred.jpg'
plot_images(img, output_to_target(output, width, height), paths, str(f), names) # predictions

# Log the images to W&B
if len(wandb_image_log) > 0:
wandb.log({"outputs":wandb_image_log})

# Compute statistics
stats = [np.concatenate(x, 0) for x in zip(*stats)] # to numpy
if len(stats) and stats[0].any():
@@ -218,19 +252,19 @@ def test(data,

# Save JSON
if save_json and len(jdict):
f = 'detections_val2017_%s_results.json' % \
(weights.split(os.sep)[-1].replace('.pt', '') if isinstance(weights, str) else '') # filename
print('\nCOCO mAP with pycocotools... saving %s...' % f)
with open(f, 'w') as file:
json.dump(jdict, file)
w = Path(weights[0] if isinstance(weights, list) else weights).stem if weights is not None else '' # weights
file = save_dir / f"detections_val2017_{w}_results.json" # predicted annotations file
print('\nCOCO mAP with pycocotools... saving %s...' % file)
with open(file, 'w') as f:
json.dump(jdict, f)
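# jdict holds detections in COCO result format, ready for pycocotools evaluation below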

try: # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval

imgIds = [int(Path(x).stem) for x in dataloader.dataset.img_files]
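# assumes COCO-style filenames whose stems are integer image ids (e.g. 000000000139.jpg)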
cocoGt = COCO(glob.glob('../coco/annotations/instances_val*.json')[0]) # initialize COCO ground truth api
cocoDt = cocoGt.loadRes(f) # initialize COCO pred api
cocoDt = cocoGt.loadRes(str(file)) # initialize COCO pred api
cocoEval = COCOeval(cocoGt, cocoDt, 'bbox')
cocoEval.params.imgIds = imgIds # image IDs to evaluate
cocoEval.evaluate()
@@ -263,6 +297,8 @@ def test(data,
parser.add_argument('--augment', action='store_true', help='augmented inference')
parser.add_argument('--verbose', action='store_true', help='report mAP by class')
parser.add_argument('--save-txt', action='store_true', help='save results to *.txt')
parser.add_argument('--save-conf', action='store_true', help='save confidences in --save-txt labels')
parser.add_argument('--save-dir', type=str, default='runs/test', help='directory to save results')
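# illustrative invocation using the new flags (data/weights values are examples):
#   python test.py --data coco.yaml --weights yolov5s.pt --save-txt --save-conf --save-dir runs/test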
opt = parser.parse_args()
opt.save_json |= opt.data.endswith('coco.yaml')
opt.data = check_file(opt.data) # check file
@@ -278,7 +314,13 @@ def test(data,
opt.save_json,
opt.single_cls,
opt.augment,
opt.verbose)
opt.verbose,
save_dir=Path(opt.save_dir),
save_txt=opt.save_txt,
save_conf=opt.save_conf,
)

print('Results saved to %s' % opt.save_dir)

elif opt.task == 'study': # run over a range of settings and save/plot
for weights in ['yolov5s.pt', 'yolov5m.pt', 'yolov5l.pt', 'yolov5x.pt']: