Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New Colors() class #2963

Merged
merged 1 commit into from
Apr 28, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 3 additions & 6 deletions detect.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from utils.datasets import LoadStreams, LoadImages
from utils.general import check_img_size, check_requirements, check_imshow, non_max_suppression, apply_classifier, \
scale_coords, xyxy2xywh, strip_optimizer, set_logging, increment_path, save_one_box
from utils.plots import plot_one_box
from utils.plots import colors, plot_one_box
from utils.torch_utils import select_device, load_classifier, time_synchronized


Expand All @@ -34,6 +34,7 @@ def detect(opt):
model = attempt_load(weights, map_location=device) # load FP32 model
stride = int(model.stride.max()) # model stride
imgsz = check_img_size(imgsz, s=stride) # check img_size
names = model.module.names if hasattr(model, 'module') else model.names # get class names
if half:
model.half() # to FP16

Expand All @@ -52,10 +53,6 @@ def detect(opt):
else:
dataset = LoadImages(source, img_size=imgsz, stride=stride)

# Get names and colors
names = model.module.names if hasattr(model, 'module') else model.names
colors = [[random.randint(0, 255) for _ in range(3)] for _ in names]

# Run inference
if device.type != 'cpu':
model(torch.zeros(1, 3, imgsz, imgsz).to(device).type_as(next(model.parameters()))) # run once
Expand Down Expand Up @@ -112,7 +109,7 @@ def detect(opt):
c = int(cls) # integer class
label = None if opt.hide_labels else (names[c] if opt.hide_conf else f'{names[c]} {conf:.2f}')

plot_one_box(xyxy, im0, label=label, color=colors[c], line_thickness=opt.line_thickness)
plot_one_box(xyxy, im0, label=label, color=colors(c, True), line_thickness=opt.line_thickness)
if opt.save_crop:
save_one_box(xyxy, im0s, file=save_dir / 'crops' / names[c] / f'{p.stem}.jpg', BGR=True)

Expand Down
5 changes: 2 additions & 3 deletions models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from utils.datasets import letterbox
from utils.general import non_max_suppression, make_divisible, scale_coords, increment_path, xyxy2xywh, save_one_box
from utils.plots import color_list, plot_one_box
from utils.plots import colors, plot_one_box
from utils.torch_utils import time_synchronized


Expand Down Expand Up @@ -312,7 +312,6 @@ def __init__(self, imgs, pred, files, times=None, names=None, shape=None):
self.s = shape # inference BCHW shape

def display(self, pprint=False, show=False, save=False, crop=False, render=False, save_dir=Path('')):
colors = color_list()
for i, (im, pred) in enumerate(zip(self.imgs, self.pred)):
str = f'image {i + 1}/{len(self.pred)}: {im.shape[0]}x{im.shape[1]} '
if pred is not None:
Expand All @@ -325,7 +324,7 @@ def display(self, pprint=False, show=False, save=False, crop=False, render=False
if crop:
save_one_box(box, im, file=save_dir / 'crops' / self.names[int(cls)] / self.files[i])
else: # all others
plot_one_box(box, im, label=label, color=colors[int(cls) % 10])
plot_one_box(box, im, label=label, color=colors(cls))

im = Image.fromarray(im.astype(np.uint8)) if isinstance(im, np.ndarray) else im # from np
if pprint:
Expand Down
24 changes: 16 additions & 8 deletions utils/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,22 @@
matplotlib.use('Agg') # for writing to files only


def color_list():
# Return first 10 plt colors as (r,g,b) https://stackoverflow.com/questions/51350872/python-from-color-name-to-rgb
def hex2rgb(h):
class Colors:
# Ultralytics color palette https://ultralytics.com/
def __init__(self):
self.palette = [self.hex2rgb(c) for c in matplotlib.colors.TABLEAU_COLORS.values()]
self.n = len(self.palette)

def __call__(self, i, bgr=False):
c = self.palette[int(i) % self.n]
return (c[2], c[1], c[0]) if bgr else c

@staticmethod
def hex2rgb(h): # rgb order (PIL)
return tuple(int(h[1 + i:1 + i + 2], 16) for i in (0, 2, 4))

return [hex2rgb(h) for h in matplotlib.colors.TABLEAU_COLORS.values()] # or BASE_ (8), CSS4_ (148), XKCD_ (949)

colors = Colors() # create instance for 'from utils.plots import colors'


def hist2d(x, y, n=100):
Expand Down Expand Up @@ -137,7 +147,6 @@ def plot_images(images, targets, paths=None, fname='images.jpg', names=None, max
h = math.ceil(scale_factor * h)
w = math.ceil(scale_factor * w)

colors = color_list() # list of colors
mosaic = np.full((int(ns * h), int(ns * w), 3), 255, dtype=np.uint8) # init
for i, img in enumerate(images):
if i == max_subplots: # if last batch has fewer images than we expect
Expand Down Expand Up @@ -168,7 +177,7 @@ def plot_images(images, targets, paths=None, fname='images.jpg', names=None, max
boxes[[1, 3]] += block_y
for j, box in enumerate(boxes.T):
cls = int(classes[j])
color = colors[cls % len(colors)]
color = colors(cls)
cls = names[cls] if names else cls
if labels or conf[j] > 0.25: # 0.25 conf thresh
label = '%s' % cls if labels else '%s %.1f' % (cls, conf[j])
Expand Down Expand Up @@ -276,7 +285,6 @@ def plot_labels(labels, names=(), save_dir=Path(''), loggers=None):
print('Plotting labels... ')
c, b = labels[:, 0], labels[:, 1:].transpose() # classes, boxes
nc = int(c.max() + 1) # number of classes
colors = color_list()
x = pd.DataFrame(b.transpose(), columns=['x', 'y', 'width', 'height'])

# seaborn correlogram
Expand All @@ -302,7 +310,7 @@ def plot_labels(labels, names=(), save_dir=Path(''), loggers=None):
labels[:, 1:] = xywh2xyxy(labels[:, 1:]) * 2000
img = Image.fromarray(np.ones((2000, 2000, 3), dtype=np.uint8) * 255)
for cls, *box in labels[:1000]:
ImageDraw.Draw(img).rectangle(box, width=1, outline=colors[int(cls) % 10]) # plot
ImageDraw.Draw(img).rectangle(box, width=1, outline=colors(cls)) # plot
ax[1].imshow(img)
ax[1].axis('off')

Expand Down