Feature/production inference diana #7

Merged · 4 commits · Apr 20, 2023
39 changes: 31 additions & 8 deletions detect.py
@@ -43,9 +43,29 @@
 ROOT = Path(os.path.relpath(ROOT, Path.cwd()))  # relative

 from models.common import DetectMultiBackend
-from utils.dataloaders import IMG_FORMATS, VID_FORMATS, LoadImages, LoadScreenshots, LoadStreams
-from utils.general import (LOGGER, Profile, check_file, check_img_size, check_imshow, check_requirements, colorstr, cv2,
-                           increment_path, non_max_suppression, print_args, scale_boxes, strip_optimizer, xyxy2xywh)
+from utils.dataloaders import (
+    IMG_FORMATS,
+    VID_FORMATS,
+    LoadImages,
+    LoadScreenshots,
+    LoadStreams,
+)
+from utils.general import (
+    LOGGER,
+    Profile,
+    check_file,
+    check_img_size,
+    check_imshow,
+    check_requirements,
+    colorstr,
+    cv2,
+    increment_path,
+    non_max_suppression,
+    print_args,
+    scale_boxes,
+    strip_optimizer,
+    xyxy2xywh,
+)
 from utils.plots import Annotator, colors, save_one_box
 from utils.torch_utils import select_device, smart_inference_mode
@@ -145,10 +165,13 @@ def run(
                 p, im0, frame = path, im0s.copy(), getattr(dataset, 'frame', 0)

             p = Path(p)  # to Path
-            # Removes the absolute mounted path part that changes at every run.
-            relative_path_in_azure_mounted_folder = Path("/".join(p.parts[p.parts.index("wd")+2:]))
-            save_path = str(save_dir / relative_path_in_azure_mounted_folder)  # im.jpg
-            txt_path = str(save_dir / 'labels' / relative_path_in_azure_mounted_folder.parent / relative_path_in_azure_mounted_folder.stem) + ('' if dataset.mode == 'image' else f'_{frame}')  # im.txt
+            is_wd_path = 'wd' in p.parts
+            relative_path_in_azure_mounted_folder = Path('/'.join(p.parts[p.parts.index('wd') +
+                                                                          2:])) if is_wd_path else None
+            save_path = str(save_dir / (relative_path_in_azure_mounted_folder if is_wd_path else p.name))
+            txt_path = str(save_dir / 'labels' /
+                           (relative_path_in_azure_mounted_folder.parent / relative_path_in_azure_mounted_folder.stem
+                            if is_wd_path else p.stem)) + ('' if dataset.mode == 'image' else f'_{frame}')

             s += '%gx%g ' % im.shape[2:]  # print string
             gn = torch.tensor(im0.shape)[[1, 0, 1, 0]]  # normalization gain whwh
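The new `is_wd_path` guard makes the Azure path stripping conditional, so inputs that don't come from a mounted `wd` folder fall back to the plain file name. For reference, a minimal sketch of what the `'wd'` anchor computes, using a hypothetical mount path (the path itself is made up, not from the PR):

```python
from pathlib import Path

# Hypothetical Azure ML mount; the prefix before 'wd' changes on every run.
p = Path('/mnt/azureml/abc123/wd/inputs/images/site1/im.jpg')

if 'wd' in p.parts:
    # Skip everything up to and including 'wd' plus one more segment ('inputs'),
    # keeping only the run-independent tail used for save_path/txt_path.
    rel = Path('/'.join(p.parts[p.parts.index('wd') + 2:]))
    print(rel)  # images/site1/im.jpg
```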
@@ -200,7 +223,7 @@ def run(
                             save_path,
                             im0,
                         ):
-                            raise Exception(f"Could not write image {os.path.basename(save_path)}")
+                            raise Exception(f'Could not write image {os.path.basename(save_path)}')

             # Stream results
             im0 = annotator.result()
55 changes: 38 additions & 17 deletions utils/dataloaders.py
@@ -27,15 +27,36 @@
 from PIL import ExifTags, Image, ImageOps
 from torch.utils.data import DataLoader, Dataset, dataloader, distributed
 from tqdm import tqdm
-from utils.augmentations import (Albumentations, augment_hsv,
-                                 classify_albumentations, classify_transforms,
-                                 copy_paste, letterbox, mixup,
-                                 random_perspective)
-from utils.general import (DATASETS_DIR, LOGGER, NUM_THREADS, TQDM_BAR_FORMAT,
-                           check_dataset, check_requirements, check_yaml,
-                           clean_str, cv2, is_colab, is_kaggle, segments2boxes,
-                           unzip_file, xyn2xy, xywh2xyxy, xywhn2xyxy,
-                           xyxy2xywhn)
+
+from utils.augmentations import (
+    Albumentations,
+    augment_hsv,
+    classify_albumentations,
+    classify_transforms,
+    copy_paste,
+    letterbox,
+    mixup,
+    random_perspective,
+)
+from utils.general import (
+    DATASETS_DIR,
+    LOGGER,
+    NUM_THREADS,
+    TQDM_BAR_FORMAT,
+    check_dataset,
+    check_requirements,
+    check_yaml,
+    clean_str,
+    cv2,
+    is_colab,
+    is_kaggle,
+    segments2boxes,
+    unzip_file,
+    xyn2xy,
+    xywh2xyxy,
+    xywhn2xyxy,
+    xyxy2xywhn,
+)
 from utils.torch_utils import torch_distributed_zero_first

 # Parameters
@@ -672,7 +693,7 @@ def __getitem__(self, index):
         else:
             # Load image
             img, (h0, w0), (h, w) = self.load_image(index)
-
+            im0 = img.copy()
             # Letterbox
             shape = self.batch_shapes[self.batch[index]] if self.rect else self.img_size  # final letterboxed shape
             img, ratio, pad = letterbox(img, shape, auto=False, scaleup=self.augment)
@@ -727,7 +748,7 @@ def __getitem__(self, index):
         img = img.transpose((2, 0, 1))[::-1]  # HWC to CHW, BGR to RGB
         img = np.ascontiguousarray(img)

-        return torch.from_numpy(img), labels_out, self.im_files[index], shapes
+        return im0, torch.from_numpy(img), labels_out, self.im_files[index], shapes

     def load_image(self, i):
         # Loads 1 image from dataset index 'i', returns (im, original hw, resized hw)
@@ -889,14 +910,14 @@ def load_mosaic9(self, index):

     @staticmethod
     def collate_fn(batch):
-        im, label, path, shapes = zip(*batch)  # transposed
+        im0, im, label, path, shapes = zip(*batch)  # transposed
         for i, lb in enumerate(label):
             lb[:, 0] = i  # add target image index for build_targets()
-        return torch.stack(im, 0), torch.cat(label, 0), path, shapes
+        return im0, torch.stack(im, 0), torch.cat(label, 0), path, shapes

     @staticmethod
     def collate_fn4(batch):
-        im, label, path, shapes = zip(*batch)  # transposed
+        im0, im, label, path, shapes = zip(*batch)  # transposed
         n = len(shapes) // 4
         im4, label4, path4, shapes4 = [], [], path[:n], shapes[:n]
@@ -918,7 +939,7 @@ def collate_fn4(batch):
         for i, lb in enumerate(label4):
             lb[:, 0] = i  # add target image index for build_targets()

-        return torch.stack(im4, 0), torch.cat(label4, 0), path4, shapes4
+        return im0, torch.stack(im4, 0), torch.cat(label4, 0), path4, shapes4


 # Ancillary functions --------------------------------------------------------------------------------------------------
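Since `__getitem__` now returns the un-letterboxed original image `im0` in front of the model-input tensor, both collate functions pass `im0` through as a plain tuple rather than stacking it — the originals can differ in shape. A standalone sketch of the new batch layout, with made-up sample data:

```python
import numpy as np
import torch

def collate_fn(batch):  # mirrors the updated collate_fn above
    im0, im, label, path, shapes = zip(*batch)  # transposed
    for i, lb in enumerate(label):
        lb[:, 0] = i  # add target image index for build_targets()
    return im0, torch.stack(im, 0), torch.cat(label, 0), path, shapes

items = [
    (np.zeros((480, 640, 3), np.uint8), torch.zeros(3, 640, 640), torch.zeros(1, 6), 'a.jpg', None),
    (np.zeros((720, 1280, 3), np.uint8), torch.zeros(3, 640, 640), torch.zeros(2, 6), 'b.jpg', None),
]
im0, im, labels, paths, shapes = collate_fn(items)
print(len(im0), im.shape, labels.shape)  # 2 torch.Size([2, 3, 640, 640]) torch.Size([3, 6])
```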
@@ -1027,9 +1048,9 @@ def verify_image_label(args):
             nl = len(lb)
             if nl:
                 if lb.shape[1] == 5:
-                    LOGGER.info(f'Loading labels with format [cls x_c y_c width height]')
+                    LOGGER.info('Loading labels with format [cls x_c y_c width height]')
                 if lb.shape[1] == 6:
-                    LOGGER.info(f'Loading labels with format [cls x_c y_c width height tagged_cls]')
+                    LOGGER.info('Loading labels with format [cls x_c y_c width height tagged_cls]')
                 assert (lb >= 0).all(), f'negative label values {lb[lb < 0]}'
                 _, i = np.unique(lb, axis=0, return_index=True)
                 if len(i) < nl:  # duplicate row check
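The two log lines correspond to five- and six-column label rows; a sketch of the layouts with made-up values (the sixth column presumably feeds the `--tagged-data` path):

```python
import numpy as np

plain = np.array([[0, 0.50, 0.50, 0.20, 0.30]])      # [cls x_c y_c width height]
tagged = np.array([[0, 0.50, 0.50, 0.20, 0.30, 7]])  # [cls x_c y_c width height tagged_cls]

for lb in (plain, tagged):
    fmt = '[cls x_c y_c width height]' if lb.shape[1] == 5 else '[cls x_c y_c width height tagged_cls]'
    print(f'Loading labels with format {fmt}')
```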
109 changes: 75 additions & 34 deletions val.py
@@ -3,7 +3,7 @@
 Validate a trained YOLOv5 detection model on a detection dataset

 Usage:
-    $ python val.py --weights yolov5s.pt --data coco128.yaml --img 640
+    $ python val.py --weights weights/best.pt --data data/pano.yaml --skip-evaluation --save_blurred_image

 Usage - formats:
     $ python val.py --weights yolov5s.pt                 # PyTorch
@@ -49,6 +49,7 @@
     check_yaml,
     coco80_to_coco91_class,
     colorstr,
+    cv2,
     increment_path,
     non_max_suppression,
     print_args,
@@ -173,7 +174,9 @@ def run(
         plots=True,
         callbacks=Callbacks(),
         compute_loss=None,
-        tagged_data=False):
+        tagged_data=False,
+        skip_evaluation=False,
+        save_blurred_image=False):
     # Initialize/load model and set device
     training = model is not None
     if training:  # called by train.py
@@ -245,7 +248,8 @@ def run(
     jdict, stats, ap, ap_class = [], [], [], []
     callbacks.run('on_val_start')
     pbar = tqdm(dataloader, desc=s, bar_format=TQDM_BAR_FORMAT)  # progress bar
-    for batch_i, (im, targets, paths, shapes) in enumerate(pbar):
+    for batch_i, (im0, im, targets, paths, shapes) in enumerate(pbar):
+
         callbacks.run('on_val_batch_start')
         if tagged_data:
             confusion_matrix = TaggedConfusionMatrix(nc=nc)
@@ -291,8 +295,14 @@ def run(
             correct = torch.zeros(npr, niou, dtype=torch.bool, device=device)  # init
             seen += 1

+            p = Path(path)  # to Path
+            is_wd_path = 'wd' in p.parts
+            relative_path_in_azure_mounted_folder = Path('/'.join(p.parts[p.parts.index('wd') +
+                                                                          2:])) if is_wd_path else None
+            save_path = str(save_dir / (relative_path_in_azure_mounted_folder if is_wd_path else p.name))
+
             if npr == 0:
-                if nl:
+                if nl and not skip_evaluation:
                     stats.append((correct, *torch.zeros((2, 0), device=device), labels[:, 0]))
                     if plots:
                         confusion_matrix.process_batch(detections=None, labels=labels[:, 0])
@@ -302,20 +312,22 @@ def run(
             if single_cls:
                 pred[:, 5] = 0
             predn = pred.clone()
+            pred_clone = pred.clone()
             scale_boxes(im[si].shape[1:], predn[:, :4], shape, shapes[si][1])  # native-space pred

             # Evaluate
-            if nl:
-                tbox = xywh2xyxy(labels[:, 1:5])  # target boxes
-                scale_boxes(im[si].shape[1:], tbox, shape, shapes[si][1])  # native-space labels
-                labelsn = torch.cat((labels[:, 0:1], tbox), 1)  # native-space labels
-                correct = process_batch(predn, labelsn, iouv)
-                if plots:
-                    if tagged_data:
-                        confusion_matrix.process_batch(predn, labelsn, gt_boxes, tagged_labels)
-                    else:
-                        confusion_matrix.process_batch(predn, labelsn)
-            stats.append((correct, pred[:, 4], pred[:, 5], labels[:, 0]))  # (correct, conf, pcls, tcls)
+            if not skip_evaluation:
+                if nl:
+                    tbox = xywh2xyxy(labels[:, 1:5])  # target boxes
+                    scale_boxes(im[si].shape[1:], tbox, shape, shapes[si][1])  # native-space labels
+                    labelsn = torch.cat((labels[:, 0:1], tbox), 1)  # native-space labels
+                    correct = process_batch(predn, labelsn, iouv)
+                    if plots:
+                        if tagged_data:
+                            confusion_matrix.process_batch(predn, labelsn, gt_boxes, tagged_labels)
+                        else:
+                            confusion_matrix.process_batch(predn, labelsn)
+                stats.append((correct, pred[:, 4], pred[:, 5], labels[:, 0]))  # (correct, conf, pcls, tcls)

             # Save/log
             if save_txt:
@@ -332,31 +344,54 @@ def run(
                     save_one_json(predn, jdict, path, class_map)  # append to COCO-JSON dictionary
                 callbacks.run('on_val_image_end', pred, predn, path, names, im[si])

+            # ======== SAVE BLURRED ======== #
+            if save_blurred_image:
+                pred_clone[:, :4] = scale_boxes(im.shape[2:], pred_clone[:, :4],
+                                                im0[si].shape).round()  # TODO why not im[si].shape[2:]
+
+                for *xyxy, conf, cls in pred_clone.tolist():
+                    x1, y1 = int(xyxy[0]), int(xyxy[1])
+                    x2, y2 = int(xyxy[2]), int(xyxy[3])
+                    area_to_blur = im0[si][y1:y2, x1:x2]

+                    blurred = cv2.GaussianBlur(area_to_blur, (135, 135), 0)
+                    im0[si][y1:y2, x1:x2] = blurred
+
+                folder_path = os.path.dirname(save_path)
+                if not os.path.exists(folder_path):
+                    os.makedirs(folder_path)
+                if not cv2.imwrite(
+                        save_path,
+                        im0[si],
+                ):
+                    raise Exception(f'Could not write image {os.path.basename(save_path)}')
+            # ======== END SAVE BLURRED ======== #
         # Plot images
-        if plots:
+        if plots and not skip_evaluation:
             plot_images(im, targets, paths, save_dir / f'{path.stem}.jpg', names)  # labels
             plot_images(im, output_to_target(preds), paths, save_dir / f'{path.stem}_pred.jpg', names)  # pred

         callbacks.run('on_val_batch_end', batch_i, im, targets, paths, shapes, preds)

     # Compute metrics
-    stats = [torch.cat(x, 0).cpu().numpy() for x in zip(*stats)]  # to numpy
-    if len(stats) and stats[0].any():
-        tp, fp, p, r, f1, ap, ap_class = ap_per_class(*stats, plot=plots, save_dir=save_dir, names=names)
-        ap50, ap = ap[:, 0], ap.mean(1)  # AP@0.5, AP@0.5:0.95
-        mp, mr, map50, map = p.mean(), r.mean(), ap50.mean(), ap.mean()
-    nt = np.bincount(stats[3].astype(int), minlength=nc)  # number of targets per class
-
-    # Print results
-    pf = '%22s' + '%11i' * 2 + '%11.3g' * 4  # print format
-    LOGGER.info(pf % ('all', seen, nt.sum(), mp, mr, map50, map))
-    if nt.sum() == 0:
-        LOGGER.warning(f'WARNING ⚠️ no labels found in {task} set, can not compute metrics without labels')
-
-    # Print results per class
-    if (verbose or (nc < 50 and not training)) and nc > 1 and len(stats):
-        for i, c in enumerate(ap_class):
-            LOGGER.info(pf % (names[c], seen, nt[c], p[i], r[i], ap50[i], ap[i]))
+    if not skip_evaluation:
+        stats = [torch.cat(x, 0).cpu().numpy() for x in zip(*stats)]  # to numpy
+        if len(stats) and stats[0].any():
+            tp, fp, p, r, f1, ap, ap_class = ap_per_class(*stats, plot=plots, save_dir=save_dir, names=names)
+            ap50, ap = ap[:, 0], ap.mean(1)  # AP@0.5, AP@0.5:0.95
+            mp, mr, map50, map = p.mean(), r.mean(), ap50.mean(), ap.mean()
+        nt = np.bincount(stats[3].astype(int), minlength=nc)  # number of targets per class
+
+        # Print results
+        pf = '%22s' + '%11i' * 2 + '%11.3g' * 4  # print format
+        LOGGER.info(pf % ('all', seen, nt.sum(), mp, mr, map50, map))
+        if nt.sum() == 0:
+            LOGGER.warning(f'WARNING ⚠️ no labels found in {task} set, can not compute metrics without labels')
+
+        # Print results per class
+        if (verbose or (nc < 50 and not training)) and nc > 1 and len(stats):
+            for i, c in enumerate(ap_class):
+                LOGGER.info(pf % (names[c], seen, nt[c], p[i], r[i], ap50[i], ap[i]))

     # Print speeds
     t = tuple(x.t / seen * 1E3 for x in dt)  # speeds per image
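For context, the blurring step in isolation: boxes are rescaled to original-image space, each detected region is replaced by a Gaussian-blurred copy, and the frame is written out. A self-contained sketch with a synthetic image and a hypothetical box — a (135, 135) kernel is valid because both dimensions are odd, and sigma 0 lets OpenCV derive it from the kernel size:

```python
import cv2
import numpy as np

im0 = np.full((400, 600, 3), 255, np.uint8)  # synthetic frame
cv2.putText(im0, 'plate', (210, 215), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 0), 2)

x1, y1, x2, y2 = 200, 180, 330, 230  # hypothetical detection box
im0[y1:y2, x1:x2] = cv2.GaussianBlur(im0[y1:y2, x1:x2], (135, 135), 0)

cv2.imwrite('blurred.jpg', im0)  # returns False on failure, as checked in the diff
```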
@@ -365,7 +400,7 @@ def run(
     LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {shape}' % t)

     # Plots
-    if plots:
+    if plots and not skip_evaluation:
         confusion_matrix.plot(save_dir=save_dir, names=list(names.values()))
         callbacks.run('on_val_end', nt, tp, fp, p, r, f1, ap, ap50, ap_class, confusion_matrix)
@@ -403,6 +438,8 @@ def run(
     maps = np.zeros(nc) + map
     for i, c in enumerate(ap_class):
         maps[c] = ap[i]
+    if skip_evaluation:
+        return (mp, mr, map50, map, []), maps, t
     return (mp, mr, map50, map, *(loss.cpu() / len(dataloader)).tolist()), maps, t
@@ -431,6 +468,8 @@ def parse_opt():
     parser.add_argument('--half', action='store_true', help='use FP16 half-precision inference')
     parser.add_argument('--dnn', action='store_true', help='use OpenCV DNN for ONNX inference')
     parser.add_argument('--tagged-data', action='store_true', help='use tagged validation')
+    parser.add_argument('--skip-evaluation', action='store_true', help='skip metrics computation for production inference')
+    parser.add_argument('--save_blurred_image', action='store_true', help='save blurred images')
     opt = parser.parse_args()
     opt.data = check_yaml(opt.data)  # check YAML
     opt.save_json |= opt.data.endswith('coco.yaml')
@@ -443,6 +482,8 @@ def main(opt):
     check_requirements(exclude=('tensorboard', 'thop'))

     if opt.task in ('train', 'val', 'test'):  # run normally
+        if opt.skip_evaluation:
+            opt.conf_thres, opt.iou_thres = 0.25, 0.45
         if opt.conf_thres > 0.001:  # https://github.com/ultralytics/yolov5/issues/1466
             LOGGER.info(f'WARNING ⚠️ confidence threshold {opt.conf_thres} > 0.001 produces invalid results')
         if opt.save_hybrid:
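The override in main() swaps the mAP-oriented defaults (conf 0.001, IoU 0.6) for deployment-style thresholds. A minimal sketch of that interaction, with the argparse defaults copied from val.py:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--conf-thres', type=float, default=0.001)  # val.py default
parser.add_argument('--iou-thres', type=float, default=0.6)     # val.py default
parser.add_argument('--skip-evaluation', action='store_true')

opt = parser.parse_args(['--skip-evaluation'])
if opt.skip_evaluation:
    opt.conf_thres, opt.iou_thres = 0.25, 0.45  # production thresholds from the diff
print(opt.conf_thres, opt.iou_thres)  # 0.25 0.45
```

Note that the override runs before the `conf_thres > 0.001` check, so production runs will still log the "produces invalid results" warning — harmless here, since metrics are skipped anyway.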