Colorstr() updates #1909

Merged · 6 commits · Jan 12, 2021
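Summary: `colorstr()` in utils/general.py now defaults to blue + bold when given a bare string, and `create_dataloader()` / `LoadImagesAndLabels` gain a `prefix` argument so dataset log lines are tagged per split ('train: ', 'val: ', 'test: ', plus 'autoanchor: ' and 'wandb: '). A minimal sketch of the new call pattern (not part of the diff; assumes the repo root is on PYTHONPATH):

```python
from utils.general import colorstr

# New shorthand: a bare string defaults to blue + bold (see utils/general.py below)
print(colorstr('train: ') + 'Scanning images...')

# Explicit color/style arguments still work as before
print(colorstr('red', 'bold', 'WARNING:') + ' label file is corrupted')
```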
5 changes: 3 additions & 2 deletions test.py
@@ -12,7 +12,7 @@
 from models.experimental import attempt_load
 from utils.datasets import create_dataloader
 from utils.general import coco80_to_coco91_class, check_dataset, check_file, check_img_size, check_requirements, \
-    box_iou, non_max_suppression, scale_coords, xyxy2xywh, xywh2xyxy, set_logging, increment_path
+    box_iou, non_max_suppression, scale_coords, xyxy2xywh, xywh2xyxy, set_logging, increment_path, colorstr
 from utils.loss import compute_loss
 from utils.metrics import ap_per_class, ConfusionMatrix
 from utils.plots import plot_images, output_to_target, plot_study_txt
@@ -86,7 +86,8 @@ def test(data,
     img = torch.zeros((1, 3, imgsz, imgsz), device=device)  # init img
     _ = model(img.half() if half else img) if device.type != 'cpu' else None  # run once
     path = data['test'] if opt.task == 'test' else data['val']  # path to val/test images
-    dataloader = create_dataloader(path, imgsz, batch_size, model.stride.max(), opt, pad=0.5, rect=True)[0]
+    dataloader = create_dataloader(path, imgsz, batch_size, model.stride.max(), opt, pad=0.5, rect=True,
+                                   prefix=colorstr('test: ' if opt.task == 'test' else 'val: '))[0]

     seen = 0
     confusion_matrix = ConfusionMatrix(nc=nc)
21 changes: 11 additions & 10 deletions train.py
@@ -36,15 +36,9 @@

 logger = logging.getLogger(__name__)

-try:
-    import wandb
-except ImportError:
-    wandb = None
-    logger.info("Install Weights & Biases for experiment logging via 'pip install wandb' (recommended)")
-

 def train(hyp, opt, device, tb_writer=None, wandb=None):
-    logger.info(colorstr('blue', 'bold', 'Hyperparameters: ') + ', '.join(f'{k}={v}' for k, v in hyp.items()))
+    logger.info(colorstr('Hyperparameters: ') + ', '.join(f'{k}={v}' for k, v in hyp.items()))
     save_dir, epochs, batch_size, total_batch_size, weights, rank = \
         Path(opt.save_dir), opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank

@@ -189,7 +183,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
     dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, opt,
                                             hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect, rank=rank,
                                             world_size=opt.world_size, workers=opt.workers,
-                                            image_weights=opt.image_weights, quad=opt.quad)
+                                            image_weights=opt.image_weights, quad=opt.quad, prefix=colorstr('train: '))
     mlc = np.concatenate(dataset.labels, 0)[:, 0].max()  # max label class
     nb = len(dataloader)  # number of batches
     assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Possible class labels are 0-%g' % (mlc, nc, opt.data, nc - 1)
@@ -198,8 +192,9 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
     if rank in [-1, 0]:
         ema.updates = start_epoch * nb // accumulate  # set EMA updates
         testloader = create_dataloader(test_path, imgsz_test, total_batch_size, gs, opt,  # testloader
-                                       hyp=hyp, cache=opt.cache_images and not opt.notest, rect=True,
-                                       rank=-1, world_size=opt.world_size, workers=opt.workers, pad=0.5)[0]
+                                       hyp=hyp, cache=opt.cache_images and not opt.notest, rect=True, rank=-1,
+                                       world_size=opt.world_size, workers=opt.workers,
+                                       pad=0.5, prefix=colorstr('val: '))[0]

     if not opt.resume:
         labels = np.concatenate(dataset.labels, 0)
@@ -514,6 +509,12 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):

     # Train
     logger.info(opt)
+    try:
+        import wandb
+    except ImportError:
+        wandb = None
+        prefix = colorstr('wandb: ')
+        logger.info(f"{prefix}Install Weights & Biases for YOLOv5 logging with 'pip install wandb' (recommended)")
     if not opt.evolve:
         tb_writer = None  # init loggers
         if opt.global_rank in [-1, 0]:
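The wandb import guard moves from module scope into the training entry point (last hunk above), so the install hint is only logged when a run actually starts, tagged with a colorstr('wandb: ') prefix. A standalone sketch of this deferred optional-import pattern; the helper name `try_import_wandb` is hypothetical, not part of the diff:

```python
import logging

logger = logging.getLogger(__name__)


def try_import_wandb(prefix='wandb: '):
    # Deferred optional import: the install hint is only logged when this is
    # actually called (e.g. at training start), not at module import time.
    try:
        import wandb
        return wandb
    except ImportError:
        logger.info(f"{prefix}Install Weights & Biases for YOLOv5 logging with 'pip install wandb' (recommended)")
        return None
```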
4 changes: 2 additions & 2 deletions utils/autoanchor.py
@@ -22,7 +22,7 @@ def check_anchor_order(m):

 def check_anchors(dataset, model, thr=4.0, imgsz=640):
     # Check anchor fit to data, recompute if necessary
-    prefix = colorstr('blue', 'bold', 'autoanchor') + ': '
+    prefix = colorstr('autoanchor: ')
     print(f'\n{prefix}Analyzing anchors... ', end='')
     m = model.module.model[-1] if hasattr(model, 'module') else model.model[-1]  # Detect()
     shapes = imgsz * dataset.shapes / dataset.shapes.max(1, keepdims=True)
@@ -73,7 +73,7 @@ def kmean_anchors(path='./data/coco128.yaml', n=9, img_size=640, thr=4.0, gen=10
         from utils.autoanchor import *; _ = kmean_anchors()
     """
     thr = 1. / thr
-    prefix = colorstr('blue', 'bold', 'autoanchor') + ': '
+    prefix = colorstr('autoanchor: ')

     def metric(k, wh):  # compute metrics
         r = wh[:, None] / k[None]
54 changes: 27 additions & 27 deletions utils/datasets.py
@@ -56,7 +56,7 @@ def exif_size(img):


 def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=False, cache=False, pad=0.0, rect=False,
-                      rank=-1, world_size=1, workers=8, image_weights=False, quad=False):
+                      rank=-1, world_size=1, workers=8, image_weights=False, quad=False, prefix=''):
     # Make sure only the first process in DDP process the dataset first, and the following others can use the cache
     with torch_distributed_zero_first(rank):
         dataset = LoadImagesAndLabels(path, imgsz, batch_size,
@@ -67,8 +67,8 @@ def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=Fa
                                       single_cls=opt.single_cls,
                                       stride=int(stride),
                                       pad=pad,
-                                      rank=rank,
-                                      image_weights=image_weights)
+                                      image_weights=image_weights,
+                                      prefix=prefix)

     batch_size = min(batch_size, len(dataset))
     nw = min([os.cpu_count() // world_size, batch_size if batch_size > 1 else 0, workers])  # number of workers
@@ -129,7 +129,7 @@ def __init__(self, path, img_size=640):
         elif os.path.isfile(p):
             files = [p]  # files
         else:
-            raise Exception('ERROR: %s does not exist' % p)
+            raise Exception(f'ERROR: {p} does not exist')

         images = [x for x in files if x.split('.')[-1].lower() in img_formats]
         videos = [x for x in files if x.split('.')[-1].lower() in vid_formats]
@@ -144,8 +144,8 @@ def __init__(self, path, img_size=640):
             self.new_video(videos[0])  # new video
         else:
             self.cap = None
-        assert self.nf > 0, 'No images or videos found in %s. Supported formats are:\nimages: %s\nvideos: %s' % \
-                            (p, img_formats, vid_formats)
+        assert self.nf > 0, f'No images or videos found in {p}. ' \
+                            f'Supported formats are:\nimages: {img_formats}\nvideos: {vid_formats}'

     def __iter__(self):
         self.count = 0
@@ -171,14 +171,14 @@ def __next__(self):
                 ret_val, img0 = self.cap.read()

             self.frame += 1
-            print('video %g/%g (%g/%g) %s: ' % (self.count + 1, self.nf, self.frame, self.nframes, path), end='')
+            print(f'video {self.count + 1}/{self.nf} ({self.frame}/{self.nframes}) {path}: ', end='')

         else:
             # Read image
             self.count += 1
             img0 = cv2.imread(path)  # BGR
             assert img0 is not None, 'Image Not Found ' + path
-            print('image %g/%g %s: ' % (self.count, self.nf, path), end='')
+            print(f'image {self.count}/{self.nf} {path}: ', end='')

         # Padded resize
         img = letterbox(img0, new_shape=self.img_size)[0]
@@ -238,9 +238,9 @@ def __next__(self):
                        break

         # Print
-        assert ret_val, 'Camera Error %s' % self.pipe
+        assert ret_val, f'Camera Error {self.pipe}'
         img_path = 'webcam.jpg'
-        print('webcam %g: ' % self.count, end='')
+        print(f'webcam {self.count}: ', end='')

         # Padded resize
         img = letterbox(img0, new_shape=self.img_size)[0]
@@ -271,15 +271,15 @@ def __init__(self, sources='streams.txt', img_size=640):
         self.sources = [clean_str(x) for x in sources]  # clean source names for later
         for i, s in enumerate(sources):
             # Start the thread to read frames from the video stream
-            print('%g/%g: %s... ' % (i + 1, n, s), end='')
+            print(f'{i + 1}/{n}: {s}... ', end='')
             cap = cv2.VideoCapture(eval(s) if s.isnumeric() else s)
-            assert cap.isOpened(), 'Failed to open %s' % s
+            assert cap.isOpened(), f'Failed to open {s}'
             w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
             h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
             fps = cap.get(cv2.CAP_PROP_FPS) % 100
             _, self.imgs[i] = cap.read()  # guarantee first frame
             thread = Thread(target=self.update, args=([i, cap]), daemon=True)
-            print(' success (%gx%g at %.2f FPS).' % (w, h, fps))
+            print(f' success ({w}x{h} at {fps:.2f} FPS).')
             thread.start()
         print('')  # newline

@@ -336,7 +336,7 @@ def img2label_paths(img_paths):

 class LoadImagesAndLabels(Dataset):  # for training/testing
     def __init__(self, path, img_size=640, batch_size=16, augment=False, hyp=None, rect=False, image_weights=False,
-                 cache_images=False, single_cls=False, stride=32, pad=0.0, rank=-1):
+                 cache_images=False, single_cls=False, stride=32, pad=0.0, prefix=''):
         self.img_size = img_size
         self.augment = augment
         self.hyp = hyp
@@ -358,27 +358,27 @@ def __init__(self, path, img_size=640, batch_size=16, augment=False, hyp=None, r
                         parent = str(p.parent) + os.sep
                         f += [x.replace('./', parent) if x.startswith('./') else x for x in t]  # local to global path
                 else:
-                    raise Exception('%s does not exist' % p)
+                    raise Exception(f'{prefix}{p} does not exist')
             self.img_files = sorted([x.replace('/', os.sep) for x in f if x.split('.')[-1].lower() in img_formats])
-            assert self.img_files, 'No images found'
+            assert self.img_files, f'{prefix}No images found'
         except Exception as e:
-            raise Exception('Error loading data from %s: %s\nSee %s' % (path, e, help_url))
+            raise Exception(f'{prefix}Error loading data from {path}: {e}\nSee {help_url}')

         # Check cache
         self.label_files = img2label_paths(self.img_files)  # labels
         cache_path = Path(self.label_files[0]).parent.with_suffix('.cache')  # cached labels
         if cache_path.is_file():
             cache = torch.load(cache_path)  # load
             if cache['hash'] != get_hash(self.label_files + self.img_files) or 'results' not in cache:  # changed
-                cache = self.cache_labels(cache_path)  # re-cache
+                cache = self.cache_labels(cache_path, prefix)  # re-cache
         else:
-            cache = self.cache_labels(cache_path)  # cache
+            cache = self.cache_labels(cache_path, prefix)  # cache

         # Display cache
         [nf, nm, ne, nc, n] = cache.pop('results')  # found, missing, empty, corrupted, total
         desc = f"Scanning '{cache_path}' for images and labels... {nf} found, {nm} missing, {ne} empty, {nc} corrupted"
-        tqdm(None, desc=desc, total=n, initial=n)
-        assert nf > 0 or not augment, f'No labels found in {cache_path}. Can not train without labels. See {help_url}'
+        tqdm(None, desc=prefix + desc, total=n, initial=n)
+        assert nf > 0 or not augment, f'{prefix}No labels in {cache_path}. Can not train without labels. See {help_url}'

         # Read cache
         cache.pop('hash')  # remove hash
@@ -432,9 +432,9 @@ def __init__(self, path, img_size=640, batch_size=16, augment=False, hyp=None, r
             for i, x in pbar:
                 self.imgs[i], self.img_hw0[i], self.img_hw[i] = x  # img, hw_original, hw_resized = load_image(self, i)
                 gb += self.imgs[i].nbytes
-                pbar.desc = 'Caching images (%.1fGB)' % (gb / 1E9)
+                pbar.desc = f'{prefix}Caching images ({gb / 1E9:.1f}GB)'

-    def cache_labels(self, path=Path('./labels.cache')):
+    def cache_labels(self, path=Path('./labels.cache'), prefix=''):
         # Cache dataset labels, check images and read shapes
         x = {}  # dict
         nm, nf, ne, nc = 0, 0, 0, 0  # number missing, found, empty, duplicate
@@ -466,18 +466,18 @@ def cache_labels(self, path=Path('./labels.cache')):
                     x[im_file] = [l, shape]
                 except Exception as e:
                     nc += 1
-                    print('WARNING: Ignoring corrupted image and/or label %s: %s' % (im_file, e))
+                    print(f'{prefix}WARNING: Ignoring corrupted image and/or label {im_file}: {e}')

-            pbar.desc = f"Scanning '{path.parent / path.stem}' for images and labels... " \
+            pbar.desc = f"{prefix}Scanning '{path.parent / path.stem}' for images and labels... " \
                         f"{nf} found, {nm} missing, {ne} empty, {nc} corrupted"

         if nf == 0:
-            print(f'WARNING: No labels found in {path}. See {help_url}')
+            print(f'{prefix}WARNING: No labels found in {path}. See {help_url}')

         x['hash'] = get_hash(self.label_files + self.img_files)
         x['results'] = [nf, nm, ne, nc, i + 1]
         torch.save(x, path)  # save for next time
-        logging.info(f"New cache created: {path}")
+        logging.info(f'{prefix}New cache created: {path}')
         return x

     def __len__(self):
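Every dataset-side message (scan progress, cache notices, warnings) now carries the caller-supplied prefix, so interleaved train/val logs stay distinguishable. A minimal sketch of the pattern; the `scan()` helper is hypothetical, standing in for the cache/scan loops above:

```python
from tqdm import tqdm


def scan(files, prefix=''):
    # A caller-supplied prefix (e.g. colorstr('train: ') or colorstr('val: '))
    # is prepended to every progress line, mirroring the tqdm descriptions above
    pbar = tqdm(files)
    nf = 0  # files "found"
    for f in pbar:
        nf += 1
        pbar.desc = f"{prefix}Scanning '{f}'... {nf} found"


scan(['a.jpg', 'b.jpg'], prefix='train: ')
```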
5 changes: 2 additions & 3 deletions utils/general.py
@@ -118,7 +118,7 @@ def one_cycle(y1=0.0, y2=1.0, steps=100):

 def colorstr(*input):
     # Colors a string https://en.wikipedia.org/wiki/ANSI_escape_code, i.e. colorstr('blue', 'hello world')
-    *prefix, string = input  # color arguments, string
+    *args, string = input if len(input) > 1 else ('blue', 'bold', input[0])  # color arguments, string
     colors = {'black': '\033[30m',  # basic colors
               'red': '\033[31m',
               'green': '\033[32m',
@@ -138,8 +138,7 @@ def colorstr(*input):
               'end': '\033[0m',  # misc
               'bold': '\033[1m',
               'underline': '\033[4m'}
-
-    return ''.join(colors[x] for x in prefix) + f'{string}' + colors['end']
+    return ''.join(colors[x] for x in args) + f'{string}' + colors['end']


 def labels_to_class_weights(labels, nc=80):
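For reference, the updated default-argument logic can be exercised on its own. A self-contained sketch mirroring the diff above, trimmed to a few colors for brevity (the full table covers all basic and bright colors):

```python
def colorstr(*input):
    # Mirrors the updated logic: a bare string defaults to ('blue', 'bold', string)
    *args, string = input if len(input) > 1 else ('blue', 'bold', input[0])
    colors = {'blue': '\033[34m', 'red': '\033[31m',
              'bold': '\033[1m', 'end': '\033[0m'}
    return ''.join(colors[x] for x in args) + f'{string}' + colors['end']


assert colorstr('hi') == '\033[34m\033[1mhi\033[0m'   # new default: blue + bold
assert colorstr('red', 'hi') == '\033[31mhi\033[0m'   # explicit style still works
```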
2 changes: 1 addition & 1 deletion utils/plots.py
@@ -245,9 +245,9 @@ def plot_study_txt(path='study/', x=None):  # from utils.plots import *; plot_st
              'k.-', linewidth=2, markersize=8, alpha=.25, label='EfficientDet')

     ax2.grid()
-    ax2.set_yticks(np.arange(30, 60, 5))
     ax2.set_xlim(0, 30)
     ax2.set_ylim(29, 51)
+    ax2.set_yticks(np.arange(30, 55, 5))
     ax2.set_xlabel('GPU Speed (ms/img)')
     ax2.set_ylabel('COCO AP val')
     ax2.legend(loc='lower right')