From f2066945cb179f9eb5d208050b57167c866bbee9 Mon Sep 17 00:00:00 2001 From: burhany60 Date: Sat, 17 Apr 2021 13:50:52 +0800 Subject: [PATCH 01/13] Update detect.py --- detect.py | 27 ++++++++++++++++++++++++--- 1 file changed, 24 insertions(+), 3 deletions(-) diff --git a/detect.py b/detect.py index c0707da69e6a..a1e0134f1973 100644 --- a/detect.py +++ b/detect.py @@ -1,6 +1,7 @@ import argparse import time from pathlib import Path +import numpy as np import cv2 import torch @@ -16,7 +17,7 @@ def detect(save_img=False): - source, weights, view_img, save_txt, imgsz = opt.source, opt.weights, opt.view_img, opt.save_txt, opt.img_size + source, weights, view_img, save_obj, save_txt, imgsz = opt.source, opt.weights, opt.view_img, opt.save_obj, opt.save_txt, opt.img_size save_img = not opt.nosave and not source.endswith('.txt') # save inference images webcam = source.isnumeric() or source.endswith('.txt') or source.lower().startswith( ('rtsp://', 'rtmp://', 'http://', 'https://')) @@ -24,7 +25,8 @@ def detect(save_img=False): # Directories save_dir = Path(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok)) # increment run (save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True) # make dir - + (save_dir / 'cropped' if save_obj else save_dir).mkdir(parents=True, exist_ok=True) # make dir for cropped objects + # Initialize set_logging() device = select_device(opt.device) @@ -85,7 +87,8 @@ def detect(save_img=False): p, s, im0, frame = path[i], '%g: ' % i, im0s[i].copy(), dataset.count else: p, s, im0, frame = path, '', im0s, getattr(dataset, 'frame', 0) - + + im1=im0.copy() # making a copy of the original image p = Path(p) # to Path save_path = str(save_dir / p.name) # img.jpg txt_path = str(save_dir / 'labels' / p.stem) + ('' if dataset.mode == 'image' else f'_{frame}') # img.txt @@ -101,6 +104,7 @@ def detect(save_img=False): s += f"{n} {names[int(c)]}{'s' * (n > 1)}, " # add to string # Write results + k = 0 # counter for each object in an image for *xyxy, conf, cls in reversed(det): if save_txt: # Write to file xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist() # normalized xywh @@ -111,6 +115,22 @@ def detect(save_img=False): if save_img or view_img: # Add bbox to image label = f'{names[int(cls)]} {conf:.2f}' plot_one_box(xyxy, im0, label=label, color=colors[int(cls)], line_thickness=3) + + if save_obj: # save detected objects as a separate images + x,y,w,h=int(xyxy[0]), int(xyxy[1]), int(xyxy[2] - xyxy[0]), int(xyxy[3] - xyxy[1]) + img_ = im1.astype(np.uint8) + crop_img=img_[y:y + h, x:x + w] + + #!!Generating new file path for each detected object in an image !!! + filename=p.name + filename_no_extesion=filename.split('.')[0] + extension=filename.split('.')[1] + new_filename=str(filename_no_extesion) + '_' + str(k) + '.' + str(extension) + dir_path=os.path.join(save_dir,'cropped') + filepath=os.path.join(dir_path, new_filename) + print(filepath) + cv2.imwrite(filepath, crop_img) + k+=1 # Print time (inference + NMS) print(f'{s}Done. ({t2 - t1:.3f}s)') @@ -156,6 +176,7 @@ def detect(save_img=False): parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu') parser.add_argument('--view-img', action='store_true', help='display results') parser.add_argument('--save-txt', action='store_true', help='save results to *.txt') + parser.add_argument('--save_obj', action='store_false', help='save the detected object as separate image') parser.add_argument('--save-conf', action='store_true', help='save confidences in --save-txt labels') parser.add_argument('--nosave', action='store_true', help='do not save images/videos') parser.add_argument('--classes', nargs='+', type=int, help='filter by class: --class 0, or --class 0 2 3') From f8bb3d64683638f51c74c23edaf43440cf28a022 Mon Sep 17 00:00:00 2001 From: burhany60 Date: Sat, 17 Apr 2021 14:01:23 +0800 Subject: [PATCH 02/13] Update detect.py --- detect.py | 1 + 1 file changed, 1 insertion(+) diff --git a/detect.py b/detect.py index a1e0134f1973..834bf1f6233d 100644 --- a/detect.py +++ b/detect.py @@ -2,6 +2,7 @@ import time from pathlib import Path import numpy as np +import os import cv2 import torch From d1985cb25e5982a6be8bfb4aff34c7ae6c8316bc Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Tue, 20 Apr 2021 18:32:56 +0200 Subject: [PATCH 03/13] Update greetings.yml --- .github/workflows/greetings.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/greetings.yml b/.github/workflows/greetings.yml index 59c45067ad88..ee472297107e 100644 --- a/.github/workflows/greetings.yml +++ b/.github/workflows/greetings.yml @@ -16,7 +16,7 @@ jobs: git remote add upstream https://github.com/ultralytics/yolov5.git git fetch upstream git checkout feature # <----- replace 'feature' with local branch name - git merge upstream/master + git rebase upstream/master git push -u origin -f ``` - ✅ Verify all Continuous Integration (CI) **checks are passing**. From 0f719bbea8339679ab0dc33e125e02011044d854 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Tue, 20 Apr 2021 21:54:07 +0200 Subject: [PATCH 04/13] Update cropping --- detect.py | 39 ++++++++++----------------------------- utils/general.py | 17 +++++++++++++++++ 2 files changed, 27 insertions(+), 29 deletions(-) diff --git a/detect.py b/detect.py index 834bf1f6233d..5ce749c65232 100644 --- a/detect.py +++ b/detect.py @@ -1,8 +1,6 @@ import argparse import time from pathlib import Path -import numpy as np -import os import cv2 import torch @@ -12,13 +10,13 @@ from models.experimental import attempt_load from utils.datasets import LoadStreams, LoadImages from utils.general import check_img_size, check_requirements, check_imshow, non_max_suppression, apply_classifier, \ - scale_coords, xyxy2xywh, strip_optimizer, set_logging, increment_path + scale_coords, xyxy2xywh, strip_optimizer, set_logging, increment_path, save_one_box from utils.plots import plot_one_box from utils.torch_utils import select_device, load_classifier, time_synchronized -def detect(save_img=False): - source, weights, view_img, save_obj, save_txt, imgsz = opt.source, opt.weights, opt.view_img, opt.save_obj, opt.save_txt, opt.img_size +def detect(): + source, weights, view_img, save_txt, imgsz = opt.source, opt.weights, opt.view_img, opt.save_txt, opt.img_size save_img = not opt.nosave and not source.endswith('.txt') # save inference images webcam = source.isnumeric() or source.endswith('.txt') or source.lower().startswith( ('rtsp://', 'rtmp://', 'http://', 'https://')) @@ -26,8 +24,7 @@ def detect(save_img=False): # Directories save_dir = Path(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok)) # increment run (save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True) # make dir - (save_dir / 'cropped' if save_obj else save_dir).mkdir(parents=True, exist_ok=True) # make dir for cropped objects - + # Initialize set_logging() device = select_device(opt.device) @@ -87,9 +84,8 @@ def detect(save_img=False): if webcam: # batch_size >= 1 p, s, im0, frame = path[i], '%g: ' % i, im0s[i].copy(), dataset.count else: - p, s, im0, frame = path, '', im0s, getattr(dataset, 'frame', 0) - - im1=im0.copy() # making a copy of the original image + p, s, im0, frame = path, '', im0s.copy(), getattr(dataset, 'frame', 0) + p = Path(p) # to Path save_path = str(save_dir / p.name) # img.jpg txt_path = str(save_dir / 'labels' / p.stem) + ('' if dataset.mode == 'image' else f'_{frame}') # img.txt @@ -105,7 +101,6 @@ def detect(save_img=False): s += f"{n} {names[int(c)]}{'s' * (n > 1)}, " # add to string # Write results - k = 0 # counter for each object in an image for *xyxy, conf, cls in reversed(det): if save_txt: # Write to file xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist() # normalized xywh @@ -113,25 +108,11 @@ def detect(save_img=False): with open(txt_path + '.txt', 'a') as f: f.write(('%g ' * len(line)).rstrip() % line + '\n') - if save_img or view_img: # Add bbox to image + if save_img or opt.save_crops or view_img: # Add bbox to image label = f'{names[int(cls)]} {conf:.2f}' plot_one_box(xyxy, im0, label=label, color=colors[int(cls)], line_thickness=3) - - if save_obj: # save detected objects as a separate images - x,y,w,h=int(xyxy[0]), int(xyxy[1]), int(xyxy[2] - xyxy[0]), int(xyxy[3] - xyxy[1]) - img_ = im1.astype(np.uint8) - crop_img=img_[y:y + h, x:x + w] - - #!!Generating new file path for each detected object in an image !!! - filename=p.name - filename_no_extesion=filename.split('.')[0] - extension=filename.split('.')[1] - new_filename=str(filename_no_extesion) + '_' + str(k) + '.' + str(extension) - dir_path=os.path.join(save_dir,'cropped') - filepath=os.path.join(dir_path, new_filename) - print(filepath) - cv2.imwrite(filepath, crop_img) - k+=1 + if opt.save_crops: + save_one_box(xyxy, im0s, file=save_dir / 'crops' / names[int(cls)] / f'{p.stem}.jpg') # Print time (inference + NMS) print(f'{s}Done. ({t2 - t1:.3f}s)') @@ -177,8 +158,8 @@ def detect(save_img=False): parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu') parser.add_argument('--view-img', action='store_true', help='display results') parser.add_argument('--save-txt', action='store_true', help='save results to *.txt') - parser.add_argument('--save_obj', action='store_false', help='save the detected object as separate image') parser.add_argument('--save-conf', action='store_true', help='save confidences in --save-txt labels') + parser.add_argument('--save-crops', action='store_false', help='save cropped prediction boxes') parser.add_argument('--nosave', action='store_true', help='do not save images/videos') parser.add_argument('--classes', nargs='+', type=int, help='filter by class: --class 0, or --class 0 2 3') parser.add_argument('--agnostic-nms', action='store_true', help='class-agnostic NMS') diff --git a/utils/general.py b/utils/general.py index c7d084e09326..2d228e556655 100755 --- a/utils/general.py +++ b/utils/general.py @@ -591,6 +591,23 @@ def apply_classifier(x, model, img, im0): return x +def save_one_box(xyxy, im, file='crop.jpg', gain=1.02, pad=10, square=False): + # save an image crop as filename.jpg (crop size multiplied by 'gain' and padded by 'pad' pixels) + xyxy = torch.tensor(xyxy).view(-1, 4) + b = xyxy2xywh(xyxy) # boxes + if square: + b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1) # attempt rectangle to square + b[:, 2:] = b[:, 2:] * gain + pad # box wh * gain + pad + xyxy = xywh2xyxy(b).long() + clip_coords(xyxy, im.shape) + crop = im[int(xyxy[0, 1]):int(xyxy[0, 3]), int(xyxy[0, 0]):int(xyxy[0, 2])] + + file = Path(increment_path(file)).with_suffix('.jpg') + if not file.parent.exists(): + file.parent.mkdir(parents=True, exist_ok=True) # make dir + cv2.imwrite(str(file), crop) + + def increment_path(path, exist_ok=False, sep=''): # Increment file or directory path, i.e. runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc. path = Path(path) # os-agnostic From dceec777ad05a9c9e3bfc727fe979eb7b55e7dbf Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Tue, 20 Apr 2021 21:59:03 +0200 Subject: [PATCH 05/13] cleanup --- utils/general.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/utils/general.py b/utils/general.py index 2d228e556655..18b1dd7a877c 100755 --- a/utils/general.py +++ b/utils/general.py @@ -557,7 +557,7 @@ def print_mutation(hyp, results, yaml_file='hyp_evolved.yaml', bucket=''): def apply_classifier(x, model, img, im0): - # applies a second stage classifier to yolo outputs + # Apply a second stage classifier to yolo outputs im0 = [im0] if isinstance(im0, np.ndarray) else im0 for i, d in enumerate(x): # per image if d is not None and len(d): @@ -591,8 +591,8 @@ def apply_classifier(x, model, img, im0): return x -def save_one_box(xyxy, im, file='crop.jpg', gain=1.02, pad=10, square=False): - # save an image crop as filename.jpg (crop size multiplied by 'gain' and padded by 'pad' pixels) +def save_one_box(xyxy, im, file='image.jpg', gain=1.02, pad=10, square=False): + # Save an image crop as {file} with crop size multiplied by {gain} and padded by {pad} pixels xyxy = torch.tensor(xyxy).view(-1, 4) b = xyxy2xywh(xyxy) # boxes if square: From dbb7465ca0119a640e9b156ba7290a2db6196c5d Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Tue, 20 Apr 2021 22:43:11 +0200 Subject: [PATCH 06/13] Update increment_path() --- detect.py | 2 +- test.py | 2 +- train.py | 6 +++--- utils/general.py | 17 +++++++---------- 4 files changed, 12 insertions(+), 15 deletions(-) diff --git a/detect.py b/detect.py index 5ce749c65232..d666ff7c201f 100644 --- a/detect.py +++ b/detect.py @@ -22,7 +22,7 @@ def detect(): ('rtsp://', 'rtmp://', 'http://', 'https://')) # Directories - save_dir = Path(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok)) # increment run + save_dir = increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok) # increment run (save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True) # make dir # Initialize diff --git a/test.py b/test.py index d099699bcad8..db1651d07f65 100644 --- a/test.py +++ b/test.py @@ -49,7 +49,7 @@ def test(data, device = select_device(opt.device, batch_size=batch_size) # Directories - save_dir = Path(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok)) # increment run + save_dir = increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok) # increment run (save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True) # make dir # Load model diff --git a/train.py b/train.py index 82043b7fff34..17b5ac5dda50 100644 --- a/train.py +++ b/train.py @@ -41,7 +41,7 @@ def train(hyp, opt, device, tb_writer=None): logger.info(colorstr('hyperparameters: ') + ', '.join(f'{k}={v}' for k, v in hyp.items())) save_dir, epochs, batch_size, total_batch_size, weights, rank = \ - Path(opt.save_dir), opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank + opt.save_dir, opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank # Directories wdir = save_dir / 'weights' @@ -69,7 +69,7 @@ def train(hyp, opt, device, tb_writer=None): if rank in [-1, 0]: opt.hyp = hyp # add hyperparameters run_id = torch.load(weights).get('wandb_id') if weights.endswith('.pt') and os.path.isfile(weights) else None - wandb_logger = WandbLogger(opt, Path(opt.save_dir).stem, run_id, data_dict) + wandb_logger = WandbLogger(opt, save_dir.stem, run_id, data_dict) loggers['wandb'] = wandb_logger.wandb data_dict = wandb_logger.data_dict if wandb_logger.wandb: @@ -577,7 +577,7 @@ def train(hyp, opt, device, tb_writer=None): assert opt.local_rank == -1, 'DDP mode not implemented for --evolve' opt.notest, opt.nosave = True, True # only test/save final epoch # ei = [isinstance(x, (int, float)) for x in hyp.values()] # evolvable indices - yaml_file = Path(opt.save_dir) / 'hyp_evolved.yaml' # save best result here + yaml_file = opt.save_dir / 'hyp_evolved.yaml' # save best result here if opt.bucket: os.system('gsutil cp gs://%s/evolve.txt .' % opt.bucket) # download evolve.txt if exists diff --git a/utils/general.py b/utils/general.py index 18b1dd7a877c..8c5c9c1cf0e6 100755 --- a/utils/general.py +++ b/utils/general.py @@ -601,23 +601,20 @@ def save_one_box(xyxy, im, file='image.jpg', gain=1.02, pad=10, square=False): xyxy = xywh2xyxy(b).long() clip_coords(xyxy, im.shape) crop = im[int(xyxy[0, 1]):int(xyxy[0, 3]), int(xyxy[0, 0]):int(xyxy[0, 2])] + cv2.imwrite(str(increment_path(file, mkdir=True).with_suffix('.jpg')), crop) - file = Path(increment_path(file)).with_suffix('.jpg') - if not file.parent.exists(): - file.parent.mkdir(parents=True, exist_ok=True) # make dir - cv2.imwrite(str(file), crop) - -def increment_path(path, exist_ok=False, sep=''): +def increment_path(path, exist_ok=False, sep='', mkdir=False): # Increment file or directory path, i.e. runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc. path = Path(path) # os-agnostic - if not path.exists() or exist_ok: - return str(path) - else: + if path.exists() and not exist_ok: suffix = path.suffix path = path.with_suffix('') dirs = glob.glob(f"{path}{sep}*") # similar paths matches = [re.search(rf"%s{sep}(\d+)" % path.stem, d) for d in dirs] i = [int(m.groups()[0]) for m in matches if m] # indices n = max(i) + 1 if i else 2 # increment number - return f"{path}{sep}{n}{suffix}" # update path + path = Path(f"{path}{sep}{n}{suffix}") # update path + if mkdir: + (path.parent if path.is_file() else path).mkdir(parents=True, exist_ok=True) # make directory + return path From 6f04c41d3a45b1e58944a2a7bf9f2557d079798d Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Tue, 20 Apr 2021 22:52:57 +0200 Subject: [PATCH 07/13] Update common.py --- models/common.py | 31 +++++++++++++++++++------------ 1 file changed, 19 insertions(+), 12 deletions(-) diff --git a/models/common.py b/models/common.py index 2fdc0e0b70ca..ee5ee39fc2c5 100644 --- a/models/common.py +++ b/models/common.py @@ -13,7 +13,7 @@ from torch.cuda import amp from utils.datasets import letterbox -from utils.general import non_max_suppression, make_divisible, scale_coords, increment_path, xyxy2xywh +from utils.general import non_max_suppression, make_divisible, scale_coords, increment_path, xyxy2xywh, save_one_box from utils.plots import color_list, plot_one_box from utils.torch_utils import time_synchronized @@ -311,29 +311,33 @@ def __init__(self, imgs, pred, files, times=None, names=None, shape=None): self.t = tuple((times[i + 1] - times[i]) * 1000 / self.n for i in range(3)) # timestamps (ms) self.s = shape # inference BCHW shape - def display(self, pprint=False, show=False, save=False, render=False, save_dir=''): + def display(self, pprint=False, show=False, save=False, crop=False, render=False, save_dir=Path('')): colors = color_list() - for i, (img, pred) in enumerate(zip(self.imgs, self.pred)): - str = f'image {i + 1}/{len(self.pred)}: {img.shape[0]}x{img.shape[1]} ' + for i, (im, pred) in enumerate(zip(self.imgs, self.pred)): + str = f'image {i + 1}/{len(self.pred)}: {im.shape[0]}x{im.shape[1]} ' if pred is not None: for c in pred[:, -1].unique(): n = (pred[:, -1] == c).sum() # detections per class str += f"{n} {self.names[int(c)]}{'s' * (n > 1)}, " # add to string - if show or save or render: + if show or save or render or crop: for *box, conf, cls in pred: # xyxy, confidence, class label = f'{self.names[int(cls)]} {conf:.2f}' - plot_one_box(box, img, label=label, color=colors[int(cls) % 10]) - img = Image.fromarray(img.astype(np.uint8)) if isinstance(img, np.ndarray) else img # from np + if crop: + save_one_box(box, im, file=save_dir / 'crops' / self.names[int(cls)] / self.files[i]) + else: # all others + plot_one_box(box, im, label=label, color=colors[int(cls) % 10]) + + im = Image.fromarray(im.astype(np.uint8)) if isinstance(im, np.ndarray) else im # from np if pprint: print(str.rstrip(', ')) if show: - img.show(self.files[i]) # show + im.show(self.files[i]) # show if save: f = self.files[i] - img.save(Path(save_dir) / f) # save + im.save(save_dir / f) # save print(f"{'Saved' * (i == 0)} {f}", end=',' if i < self.n - 1 else f' to {save_dir}\n') if render: - self.imgs[i] = np.asarray(img) + self.imgs[i] = np.asarray(im) def print(self): self.display(pprint=True) # print results @@ -343,10 +347,13 @@ def show(self): self.display(show=True) # show results def save(self, save_dir='runs/hub/exp'): - save_dir = increment_path(save_dir, exist_ok=save_dir != 'runs/hub/exp') # increment save_dir - Path(save_dir).mkdir(parents=True, exist_ok=True) + save_dir = increment_path(save_dir, exist_ok=save_dir != 'runs/hub/exp', mkdir=True) # increment save_dir self.display(save=True, save_dir=save_dir) # save results + def crop(self, save_dir='runs/hub/exp'): + save_dir = increment_path(save_dir, exist_ok=save_dir != 'runs/hub/exp', mkdir=True) # increment save_dir + self.display(crop=True, save_dir=save_dir) # crop results + def render(self): self.display(render=True) # render results return self.imgs From 2dd5f6d97a46c85df8aa3c21c3723338e76338df Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Tue, 20 Apr 2021 22:54:08 +0200 Subject: [PATCH 08/13] Update detect.py --- detect.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/detect.py b/detect.py index d666ff7c201f..dcce2316759c 100644 --- a/detect.py +++ b/detect.py @@ -108,10 +108,10 @@ def detect(): with open(txt_path + '.txt', 'a') as f: f.write(('%g ' * len(line)).rstrip() % line + '\n') - if save_img or opt.save_crops or view_img: # Add bbox to image + if save_img or opt.save_crop or view_img: # Add bbox to image label = f'{names[int(cls)]} {conf:.2f}' plot_one_box(xyxy, im0, label=label, color=colors[int(cls)], line_thickness=3) - if opt.save_crops: + if opt.save_crop: save_one_box(xyxy, im0s, file=save_dir / 'crops' / names[int(cls)] / f'{p.stem}.jpg') # Print time (inference + NMS) @@ -159,7 +159,7 @@ def detect(): parser.add_argument('--view-img', action='store_true', help='display results') parser.add_argument('--save-txt', action='store_true', help='save results to *.txt') parser.add_argument('--save-conf', action='store_true', help='save confidences in --save-txt labels') - parser.add_argument('--save-crops', action='store_false', help='save cropped prediction boxes') + parser.add_argument('--save-crop', action='store_true', help='save cropped prediction boxes') parser.add_argument('--nosave', action='store_true', help='do not save images/videos') parser.add_argument('--classes', nargs='+', type=int, help='filter by class: --class 0, or --class 0 2 3') parser.add_argument('--agnostic-nms', action='store_true', help='class-agnostic NMS') From e1f8617116b6396d0ea3fa585a2568b93a539754 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Tue, 20 Apr 2021 23:12:20 +0200 Subject: [PATCH 09/13] Update detect.py --- detect.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/detect.py b/detect.py index dcce2316759c..0319cb88ac0e 100644 --- a/detect.py +++ b/detect.py @@ -108,11 +108,12 @@ def detect(): with open(txt_path + '.txt', 'a') as f: f.write(('%g ' * len(line)).rstrip() % line + '\n') - if save_img or opt.save_crop or view_img: # Add bbox to image - label = f'{names[int(cls)]} {conf:.2f}' - plot_one_box(xyxy, im0, label=label, color=colors[int(cls)], line_thickness=3) - if opt.save_crop: - save_one_box(xyxy, im0s, file=save_dir / 'crops' / names[int(cls)] / f'{p.stem}.jpg') + if save_img or opt.save_crops or view_img: # Add bbox to image + c = int(cls) # integer class + label = f'{names[c]} {conf:.2f}' + plot_one_box(xyxy, im0, label=label, color=colors[c], line_thickness=3) + if opt.save_crops: + save_one_box(xyxy, im0s, file=save_dir / 'crops' / names[c] / f'{p.stem}.jpg', BGR=True) # Print time (inference + NMS) print(f'{s}Done. ({t2 - t1:.3f}s)') @@ -159,7 +160,7 @@ def detect(): parser.add_argument('--view-img', action='store_true', help='display results') parser.add_argument('--save-txt', action='store_true', help='save results to *.txt') parser.add_argument('--save-conf', action='store_true', help='save confidences in --save-txt labels') - parser.add_argument('--save-crop', action='store_true', help='save cropped prediction boxes') + parser.add_argument('--save-crops', action='store_false', help='save cropped prediction boxes') parser.add_argument('--nosave', action='store_true', help='do not save images/videos') parser.add_argument('--classes', nargs='+', type=int, help='filter by class: --class 0, or --class 0 2 3') parser.add_argument('--agnostic-nms', action='store_true', help='class-agnostic NMS') From d33a7dd0afee52eed871ef2c48f7ddf75e8a553d Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Tue, 20 Apr 2021 23:14:29 +0200 Subject: [PATCH 10/13] Update detect.py --- detect.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/detect.py b/detect.py index 0319cb88ac0e..9280a07a48b9 100644 --- a/detect.py +++ b/detect.py @@ -108,11 +108,11 @@ def detect(): with open(txt_path + '.txt', 'a') as f: f.write(('%g ' * len(line)).rstrip() % line + '\n') - if save_img or opt.save_crops or view_img: # Add bbox to image + if save_img or opt.save_crop or view_img: # Add bbox to image c = int(cls) # integer class label = f'{names[c]} {conf:.2f}' plot_one_box(xyxy, im0, label=label, color=colors[c], line_thickness=3) - if opt.save_crops: + if opt.save_crop: save_one_box(xyxy, im0s, file=save_dir / 'crops' / names[c] / f'{p.stem}.jpg', BGR=True) # Print time (inference + NMS) @@ -160,7 +160,7 @@ def detect(): parser.add_argument('--view-img', action='store_true', help='display results') parser.add_argument('--save-txt', action='store_true', help='save results to *.txt') parser.add_argument('--save-conf', action='store_true', help='save confidences in --save-txt labels') - parser.add_argument('--save-crops', action='store_false', help='save cropped prediction boxes') + parser.add_argument('--save-crop', action='store_false', help='save cropped prediction boxes') parser.add_argument('--nosave', action='store_true', help='do not save images/videos') parser.add_argument('--classes', nargs='+', type=int, help='filter by class: --class 0, or --class 0 2 3') parser.add_argument('--agnostic-nms', action='store_true', help='class-agnostic NMS') From 7f8057c2fe52ef94ebafb115e79432294a87acb4 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Tue, 20 Apr 2021 23:17:24 +0200 Subject: [PATCH 11/13] Update common.py --- models/common.py | 1 + 1 file changed, 1 insertion(+) diff --git a/models/common.py b/models/common.py index ee5ee39fc2c5..098b29127876 100644 --- a/models/common.py +++ b/models/common.py @@ -353,6 +353,7 @@ def save(self, save_dir='runs/hub/exp'): def crop(self, save_dir='runs/hub/exp'): save_dir = increment_path(save_dir, exist_ok=save_dir != 'runs/hub/exp', mkdir=True) # increment save_dir self.display(crop=True, save_dir=save_dir) # crop results + print(f'Saved crops to {save_dir}\n') def render(self): self.display(render=True) # render results From 8e6dc74f4b11dd0ce5f460d05067b6d91d9ee6c3 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Tue, 20 Apr 2021 23:28:26 +0200 Subject: [PATCH 12/13] cleanup --- models/common.py | 2 +- utils/general.py | 9 +++++---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/models/common.py b/models/common.py index 098b29127876..a28621904b0e 100644 --- a/models/common.py +++ b/models/common.py @@ -353,7 +353,7 @@ def save(self, save_dir='runs/hub/exp'): def crop(self, save_dir='runs/hub/exp'): save_dir = increment_path(save_dir, exist_ok=save_dir != 'runs/hub/exp', mkdir=True) # increment save_dir self.display(crop=True, save_dir=save_dir) # crop results - print(f'Saved crops to {save_dir}\n') + print(f'Saved results to {save_dir}\n') def render(self): self.display(render=True) # render results diff --git a/utils/general.py b/utils/general.py index 8c5c9c1cf0e6..817023f33dd3 100755 --- a/utils/general.py +++ b/utils/general.py @@ -591,7 +591,7 @@ def apply_classifier(x, model, img, im0): return x -def save_one_box(xyxy, im, file='image.jpg', gain=1.02, pad=10, square=False): +def save_one_box(xyxy, im, file='image.jpg', gain=1.02, pad=10, square=False, BGR=False): # Save an image crop as {file} with crop size multiplied by {gain} and padded by {pad} pixels xyxy = torch.tensor(xyxy).view(-1, 4) b = xyxy2xywh(xyxy) # boxes @@ -601,7 +601,7 @@ def save_one_box(xyxy, im, file='image.jpg', gain=1.02, pad=10, square=False): xyxy = xywh2xyxy(b).long() clip_coords(xyxy, im.shape) crop = im[int(xyxy[0, 1]):int(xyxy[0, 3]), int(xyxy[0, 0]):int(xyxy[0, 2])] - cv2.imwrite(str(increment_path(file, mkdir=True).with_suffix('.jpg')), crop) + cv2.imwrite(str(increment_path(file, mkdir=True).with_suffix('.jpg')), crop if BGR else crop[..., ::-1]) def increment_path(path, exist_ok=False, sep='', mkdir=False): @@ -615,6 +615,7 @@ def increment_path(path, exist_ok=False, sep='', mkdir=False): i = [int(m.groups()[0]) for m in matches if m] # indices n = max(i) + 1 if i else 2 # increment number path = Path(f"{path}{sep}{n}{suffix}") # update path - if mkdir: - (path.parent if path.is_file() else path).mkdir(parents=True, exist_ok=True) # make directory + dir = path if path.suffix == '' else path.parent # directory + if not dir.exists() and mkdir: + dir.mkdir(parents=True, exist_ok=True) # make directory return path From d090ce181bd9b4be9836e95f20c273c30c42c973 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Tue, 20 Apr 2021 23:32:34 +0200 Subject: [PATCH 13/13] Update detect.py --- detect.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/detect.py b/detect.py index 9280a07a48b9..081ae3d89e2e 100644 --- a/detect.py +++ b/detect.py @@ -160,7 +160,7 @@ def detect(): parser.add_argument('--view-img', action='store_true', help='display results') parser.add_argument('--save-txt', action='store_true', help='save results to *.txt') parser.add_argument('--save-conf', action='store_true', help='save confidences in --save-txt labels') - parser.add_argument('--save-crop', action='store_false', help='save cropped prediction boxes') + parser.add_argument('--save-crop', action='store_true', help='save cropped prediction boxes') parser.add_argument('--nosave', action='store_true', help='do not save images/videos') parser.add_argument('--classes', nargs='+', type=int, help='filter by class: --class 0, or --class 0 2 3') parser.add_argument('--agnostic-nms', action='store_true', help='class-agnostic NMS')