Skip to content

Commit

Permalink
Detection cropping+saving feature addition for detect.py and PyTorch …
Browse files Browse the repository at this point in the history
…Hub (ultralytics#2827)

* Update detect.py

* Update detect.py

* Update greetings.yml

* Update cropping

* cleanup

* Update increment_path()

* Update common.py

* Update detect.py

* Update detect.py

* Update detect.py

* Update common.py

* cleanup

* Update detect.py

Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
  • Loading branch information
burhr2 and glenn-jocher committed Apr 20, 2021
1 parent 71c6e40 commit 0cd5de7
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 29 deletions.
18 changes: 11 additions & 7 deletions detect.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,19 @@
from models.experimental import attempt_load
from utils.datasets import LoadStreams, LoadImages
from utils.general import check_img_size, check_requirements, check_imshow, non_max_suppression, apply_classifier, \
scale_coords, xyxy2xywh, strip_optimizer, set_logging, increment_path
scale_coords, xyxy2xywh, strip_optimizer, set_logging, increment_path, save_one_box
from utils.plots import plot_one_box
from utils.torch_utils import select_device, load_classifier, time_synchronized


def detect(save_img=False):
def detect():
source, weights, view_img, save_txt, imgsz = opt.source, opt.weights, opt.view_img, opt.save_txt, opt.img_size
save_img = not opt.nosave and not source.endswith('.txt') # save inference images
webcam = source.isnumeric() or source.endswith('.txt') or source.lower().startswith(
('rtsp://', 'rtmp://', 'http://', 'https://'))

# Directories
save_dir = Path(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok)) # increment run
save_dir = increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok) # increment run
(save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True) # make dir

# Initialize
Expand Down Expand Up @@ -84,7 +84,7 @@ def detect(save_img=False):
if webcam: # batch_size >= 1
p, s, im0, frame = path[i], '%g: ' % i, im0s[i].copy(), dataset.count
else:
p, s, im0, frame = path, '', im0s, getattr(dataset, 'frame', 0)
p, s, im0, frame = path, '', im0s.copy(), getattr(dataset, 'frame', 0)

p = Path(p) # to Path
save_path = str(save_dir / p.name) # img.jpg
Expand All @@ -108,9 +108,12 @@ def detect(save_img=False):
with open(txt_path + '.txt', 'a') as f:
f.write(('%g ' * len(line)).rstrip() % line + '\n')

if save_img or view_img: # Add bbox to image
label = f'{names[int(cls)]} {conf:.2f}'
plot_one_box(xyxy, im0, label=label, color=colors[int(cls)], line_thickness=3)
if save_img or opt.save_crop or view_img: # Add bbox to image
c = int(cls) # integer class
label = f'{names[c]} {conf:.2f}'
plot_one_box(xyxy, im0, label=label, color=colors[c], line_thickness=3)
if opt.save_crop:
save_one_box(xyxy, im0s, file=save_dir / 'crops' / names[c] / f'{p.stem}.jpg', BGR=True)

# Print time (inference + NMS)
print(f'{s}Done. ({t2 - t1:.3f}s)')
Expand Down Expand Up @@ -157,6 +160,7 @@ def detect(save_img=False):
parser.add_argument('--view-img', action='store_true', help='display results')
parser.add_argument('--save-txt', action='store_true', help='save results to *.txt')
parser.add_argument('--save-conf', action='store_true', help='save confidences in --save-txt labels')
parser.add_argument('--save-crop', action='store_true', help='save cropped prediction boxes')
parser.add_argument('--nosave', action='store_true', help='do not save images/videos')
parser.add_argument('--classes', nargs='+', type=int, help='filter by class: --class 0, or --class 0 2 3')
parser.add_argument('--agnostic-nms', action='store_true', help='class-agnostic NMS')
Expand Down
32 changes: 20 additions & 12 deletions models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from torch.cuda import amp

from utils.datasets import letterbox
from utils.general import non_max_suppression, make_divisible, scale_coords, increment_path, xyxy2xywh
from utils.general import non_max_suppression, make_divisible, scale_coords, increment_path, xyxy2xywh, save_one_box
from utils.plots import color_list, plot_one_box
from utils.torch_utils import time_synchronized

Expand Down Expand Up @@ -311,29 +311,33 @@ def __init__(self, imgs, pred, files, times=None, names=None, shape=None):
self.t = tuple((times[i + 1] - times[i]) * 1000 / self.n for i in range(3)) # timestamps (ms)
self.s = shape # inference BCHW shape

def display(self, pprint=False, show=False, save=False, render=False, save_dir=''):
def display(self, pprint=False, show=False, save=False, crop=False, render=False, save_dir=Path('')):
colors = color_list()
for i, (img, pred) in enumerate(zip(self.imgs, self.pred)):
str = f'image {i + 1}/{len(self.pred)}: {img.shape[0]}x{img.shape[1]} '
for i, (im, pred) in enumerate(zip(self.imgs, self.pred)):
str = f'image {i + 1}/{len(self.pred)}: {im.shape[0]}x{im.shape[1]} '
if pred is not None:
for c in pred[:, -1].unique():
n = (pred[:, -1] == c).sum() # detections per class
str += f"{n} {self.names[int(c)]}{'s' * (n > 1)}, " # add to string
if show or save or render:
if show or save or render or crop:
for *box, conf, cls in pred: # xyxy, confidence, class
label = f'{self.names[int(cls)]} {conf:.2f}'
plot_one_box(box, img, label=label, color=colors[int(cls) % 10])
img = Image.fromarray(img.astype(np.uint8)) if isinstance(img, np.ndarray) else img # from np
if crop:
save_one_box(box, im, file=save_dir / 'crops' / self.names[int(cls)] / self.files[i])
else: # all others
plot_one_box(box, im, label=label, color=colors[int(cls) % 10])

im = Image.fromarray(im.astype(np.uint8)) if isinstance(im, np.ndarray) else im # from np
if pprint:
print(str.rstrip(', '))
if show:
img.show(self.files[i]) # show
im.show(self.files[i]) # show
if save:
f = self.files[i]
img.save(Path(save_dir) / f) # save
im.save(save_dir / f) # save
print(f"{'Saved' * (i == 0)} {f}", end=',' if i < self.n - 1 else f' to {save_dir}\n')
if render:
self.imgs[i] = np.asarray(img)
self.imgs[i] = np.asarray(im)

def print(self):
self.display(pprint=True) # print results
Expand All @@ -343,10 +347,14 @@ def show(self):
self.display(show=True) # show results

def save(self, save_dir='runs/hub/exp'):
save_dir = increment_path(save_dir, exist_ok=save_dir != 'runs/hub/exp') # increment save_dir
Path(save_dir).mkdir(parents=True, exist_ok=True)
save_dir = increment_path(save_dir, exist_ok=save_dir != 'runs/hub/exp', mkdir=True) # increment save_dir
self.display(save=True, save_dir=save_dir) # save results

def crop(self, save_dir='runs/hub/exp'):
save_dir = increment_path(save_dir, exist_ok=save_dir != 'runs/hub/exp', mkdir=True) # increment save_dir
self.display(crop=True, save_dir=save_dir) # crop results
print(f'Saved results to {save_dir}\n')

def render(self):
self.display(render=True) # render results
return self.imgs
Expand Down
2 changes: 1 addition & 1 deletion test.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def test(data,
device = select_device(opt.device, batch_size=batch_size)

# Directories
save_dir = Path(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok)) # increment run
save_dir = increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok) # increment run
(save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True) # make dir

# Load model
Expand Down
6 changes: 3 additions & 3 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
def train(hyp, opt, device, tb_writer=None):
logger.info(colorstr('hyperparameters: ') + ', '.join(f'{k}={v}' for k, v in hyp.items()))
save_dir, epochs, batch_size, total_batch_size, weights, rank = \
Path(opt.save_dir), opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank
opt.save_dir, opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank

# Directories
wdir = save_dir / 'weights'
Expand Down Expand Up @@ -69,7 +69,7 @@ def train(hyp, opt, device, tb_writer=None):
if rank in [-1, 0]:
opt.hyp = hyp # add hyperparameters
run_id = torch.load(weights).get('wandb_id') if weights.endswith('.pt') and os.path.isfile(weights) else None
wandb_logger = WandbLogger(opt, Path(opt.save_dir).stem, run_id, data_dict)
wandb_logger = WandbLogger(opt, save_dir.stem, run_id, data_dict)
loggers['wandb'] = wandb_logger.wandb
data_dict = wandb_logger.data_dict
if wandb_logger.wandb:
Expand Down Expand Up @@ -577,7 +577,7 @@ def train(hyp, opt, device, tb_writer=None):
assert opt.local_rank == -1, 'DDP mode not implemented for --evolve'
opt.notest, opt.nosave = True, True # only test/save final epoch
# ei = [isinstance(x, (int, float)) for x in hyp.values()] # evolvable indices
yaml_file = Path(opt.save_dir) / 'hyp_evolved.yaml' # save best result here
yaml_file = opt.save_dir / 'hyp_evolved.yaml' # save best result here
if opt.bucket:
os.system('gsutil cp gs://%s/evolve.txt .' % opt.bucket) # download evolve.txt if exists

Expand Down
27 changes: 21 additions & 6 deletions utils/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -557,7 +557,7 @@ def print_mutation(hyp, results, yaml_file='hyp_evolved.yaml', bucket=''):


def apply_classifier(x, model, img, im0):
# applies a second stage classifier to yolo outputs
# Apply a second stage classifier to yolo outputs
im0 = [im0] if isinstance(im0, np.ndarray) else im0
for i, d in enumerate(x): # per image
if d is not None and len(d):
Expand Down Expand Up @@ -591,16 +591,31 @@ def apply_classifier(x, model, img, im0):
return x


def increment_path(path, exist_ok=False, sep=''):
def save_one_box(xyxy, im, file='image.jpg', gain=1.02, pad=10, square=False, BGR=False):
# Save an image crop as {file} with crop size multiplied by {gain} and padded by {pad} pixels
xyxy = torch.tensor(xyxy).view(-1, 4)
b = xyxy2xywh(xyxy) # boxes
if square:
b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1) # attempt rectangle to square
b[:, 2:] = b[:, 2:] * gain + pad # box wh * gain + pad
xyxy = xywh2xyxy(b).long()
clip_coords(xyxy, im.shape)
crop = im[int(xyxy[0, 1]):int(xyxy[0, 3]), int(xyxy[0, 0]):int(xyxy[0, 2])]
cv2.imwrite(str(increment_path(file, mkdir=True).with_suffix('.jpg')), crop if BGR else crop[..., ::-1])


def increment_path(path, exist_ok=False, sep='', mkdir=False):
# Increment file or directory path, i.e. runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc.
path = Path(path) # os-agnostic
if not path.exists() or exist_ok:
return str(path)
else:
if path.exists() and not exist_ok:
suffix = path.suffix
path = path.with_suffix('')
dirs = glob.glob(f"{path}{sep}*") # similar paths
matches = [re.search(rf"%s{sep}(\d+)" % path.stem, d) for d in dirs]
i = [int(m.groups()[0]) for m in matches if m] # indices
n = max(i) + 1 if i else 2 # increment number
return f"{path}{sep}{n}{suffix}" # update path
path = Path(f"{path}{sep}{n}{suffix}") # update path
dir = path if path.suffix == '' else path.parent # directory
if not dir.exists() and mkdir:
dir.mkdir(parents=True, exist_ok=True) # make directory
return path

0 comments on commit 0cd5de7

Please sign in to comment.