Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New CSV Logger #4148

Merged
merged 10 commits into from
Jul 25, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ data/*
!data/*.sh

results*.txt
results*.csv

# Datasets -------------------------------------------------------------------------------------------------------------
coco/
Expand Down
40 changes: 11 additions & 29 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import time
from copy import deepcopy
from pathlib import Path
from threading import Thread

import math
import numpy as np
Expand All @@ -38,7 +37,7 @@
check_requirements, print_mutation, set_logging, one_cycle, colorstr
from utils.google_utils import attempt_download
from utils.loss import ComputeLoss
from utils.plots import plot_images, plot_labels, plot_results, plot_evolution
from utils.plots import plot_labels, plot_evolution
from utils.torch_utils import ModelEMA, select_device, intersect_dicts, torch_distributed_zero_first, de_parallel
from utils.loggers.wandb.wandb_utils import check_wandb_resume
from utils.metrics import fitness
Expand All @@ -61,7 +60,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
# Directories
w = save_dir / 'weights' # weights dir
w.mkdir(parents=True, exist_ok=True) # make dir
last, best, results_file = w / 'last.pt', w / 'best.pt', save_dir / 'results.txt'
last, best = w / 'last.pt', w / 'best.pt'

# Hyperparameters
if isinstance(hyp, str):
Expand All @@ -88,7 +87,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary

# Loggers
if RANK in [-1, 0]:
loggers = Loggers(save_dir, results_file, weights, opt, hyp, data_dict, LOGGER).start() # loggers dict
loggers = Loggers(save_dir, weights, opt, hyp, data_dict, LOGGER).start() # loggers dict
if loggers.wandb and resume:
weights, epochs, hyp, data_dict = opt.weights, opt.epochs, opt.hyp, loggers.wandb.data_dict

Expand Down Expand Up @@ -167,10 +166,6 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
ema.ema.load_state_dict(ckpt['ema'].float().state_dict())
ema.updates = ckpt['updates']

# Results
if ckpt.get('training_results') is not None:
results_file.write_text(ckpt['training_results']) # write results.txt

# Epochs
start_epoch = ckpt['epoch'] + 1
if resume:
Expand Down Expand Up @@ -275,11 +270,11 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
# b = int(random.uniform(0.25 * imgsz, 0.75 * imgsz + gs) // gs * gs)
# dataset.mosaic_border = [b - imgsz, -b] # height, width borders

mloss = torch.zeros(4, device=device) # mean losses
mloss = torch.zeros(3, device=device) # mean losses
if RANK != -1:
train_loader.sampler.set_epoch(epoch)
pbar = enumerate(train_loader)
LOGGER.info(('\n' + '%10s' * 8) % ('Epoch', 'gpu_mem', 'box', 'obj', 'cls', 'total', 'labels', 'img_size'))
LOGGER.info(('\n' + '%10s' * 7) % ('Epoch', 'gpu_mem', 'box', 'obj', 'cls', 'labels', 'img_size'))
if RANK in [-1, 0]:
pbar = tqdm(pbar, total=nb) # progress bar
optimizer.zero_grad()
Expand Down Expand Up @@ -327,20 +322,13 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
ema.update(model)
last_opt_step = ni

# Print
# Log
if RANK in [-1, 0]:
mloss = (mloss * i + loss_items) / (i + 1) # update mean losses
mem = f'{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G' # (GB)
s = ('%10s' * 2 + '%10.4g' * 6) % (
f'{epoch}/{epochs - 1}', mem, *mloss, targets.shape[0], imgs.shape[-1])
pbar.set_description(s)

# Plot
if plots:
if ni < 3:
f = save_dir / f'train_batch{ni}.jpg' # filename
Thread(target=plot_images, args=(imgs, targets, paths, f), daemon=True).start()
loggers.on_train_batch_end(ni, model, imgs)
pbar.set_description(('%10s' * 2 + '%10.4g' * 5) % (
f'{epoch}/{epochs - 1}', mem, *mloss, targets.shape[0], imgs.shape[-1]))
loggers.on_train_batch_end(ni, model, imgs, targets, paths, plots)

# end batch ------------------------------------------------------------------------------------------------

Expand Down Expand Up @@ -371,13 +359,12 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
fi = fitness(np.array(results).reshape(1, -1)) # weighted combination of [P, R, mAP@.5, mAP@.5-.95]
if fi > best_fitness:
best_fitness = fi
loggers.on_train_val_end(mloss, results, lr, epoch, s, best_fitness, fi)
loggers.on_train_val_end(mloss, results, lr, epoch, best_fitness, fi)

# Save model
if (not nosave) or (final_epoch and not evolve): # if save
ckpt = {'epoch': epoch,
'best_fitness': best_fitness,
'training_results': results_file.read_text(),
'model': deepcopy(de_parallel(model)).half(),
'ema': deepcopy(ema.ema).half(),
'updates': ema.updates,
Expand All @@ -395,9 +382,6 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
# end training -----------------------------------------------------------------------------------------------------
if RANK in [-1, 0]:
LOGGER.info(f'{epoch - start_epoch + 1} epochs completed in {(time.time() - t0) / 3600:.3f} hours.\n')
if plots:
plot_results(save_dir=save_dir) # save as results.png

if not evolve:
if is_coco: # COCO dataset
for m in [last, best] if best.exists() else [last]: # speed, mAP tests
Expand All @@ -411,13 +395,11 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
save_dir=save_dir,
save_json=True,
plots=False)

# Strip optimizers
for f in last, best:
if f.exists():
strip_optimizer(f) # strip optimizers

loggers.on_train_end(last, best)
loggers.on_train_end(last, best, plots)

torch.cuda.empty_cache()
return results
Expand Down
63 changes: 38 additions & 25 deletions utils/loggers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
# YOLOv5 experiment logging utils

import warnings
from threading import Thread

import torch
from torch.utils.tensorboard import SummaryWriter

from utils.general import colorstr, emojis
from utils.loggers.wandb.wandb_utils import WandbLogger
from utils.plots import plot_images, plot_results
from utils.torch_utils import de_parallel

LOGGERS = ('txt', 'tb', 'wandb') # text-file, TensorBoard, Weights & Biases
LOGGERS = ('csv', 'tb', 'wandb') # text-file, TensorBoard, Weights & Biases

try:
import wandb
Expand All @@ -21,10 +23,8 @@

class Loggers():
# YOLOv5 Loggers class
def __init__(self, save_dir=None, results_file=None, weights=None, opt=None, hyp=None,
data_dict=None, logger=None, include=LOGGERS):
def __init__(self, save_dir=None, weights=None, opt=None, hyp=None, data_dict=None, logger=None, include=LOGGERS):
self.save_dir = save_dir
self.results_file = results_file
self.weights = weights
self.opt = opt
self.hyp = hyp
Expand All @@ -35,7 +35,7 @@ def __init__(self, save_dir=None, results_file=None, weights=None, opt=None, hyp
setattr(self, k, None) # init empty logger dictionary

def start(self):
self.txt = True # always log to txt
self.csv = True # always log to csv

# Message
try:
Expand Down Expand Up @@ -63,15 +63,19 @@ def start(self):

return self

def on_train_batch_end(self, ni, model, imgs):
def on_train_batch_end(self, ni, model, imgs, targets, paths, plots):
# Callback runs on train batch end
if ni == 0:
with warnings.catch_warnings():
warnings.simplefilter('ignore') # suppress jit trace warning
self.tb.add_graph(torch.jit.trace(de_parallel(model), imgs[0:1], strict=False), [])
if self.wandb and ni == 10:
files = sorted(self.save_dir.glob('train*.jpg'))
self.wandb.log({'Mosaics': [wandb.Image(str(f), caption=f.name) for f in files if f.exists()]})
if plots:
if ni == 0:
with warnings.catch_warnings():
warnings.simplefilter('ignore') # suppress jit trace warning
self.tb.add_graph(torch.jit.trace(de_parallel(model), imgs[0:1], strict=False), [])
if ni < 3:
f = self.save_dir / f'train_batch{ni}.jpg' # filename
Thread(target=plot_images, args=(imgs, targets, paths, f), daemon=True).start()
if self.wandb and ni == 10:
files = sorted(self.save_dir.glob('train*.jpg'))
self.wandb.log({'Mosaics': [wandb.Image(str(f), caption=f.name) for f in files if f.exists()]})

def on_train_epoch_end(self, epoch):
# Callback runs on train epoch end
Expand All @@ -89,21 +93,28 @@ def on_val_end(self):
files = sorted(self.save_dir.glob('val*.jpg'))
self.wandb.log({"Validation": [wandb.Image(str(f), caption=f.name) for f in files]})

def on_train_val_end(self, mloss, results, lr, epoch, s, best_fitness, fi):
# Callback runs on validation end during training
vals = list(mloss[:-1]) + list(results) + lr
tags = ['train/box_loss', 'train/obj_loss', 'train/cls_loss', # train loss
'metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95',
def on_train_val_end(self, mloss, results, lr, epoch, best_fitness, fi):
# Callback runs on val end during training
vals = list(mloss) + list(results) + lr
keys = ['train/box_loss', 'train/obj_loss', 'train/cls_loss', # train loss
'metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95', # metrics
'val/box_loss', 'val/obj_loss', 'val/cls_loss', # val loss
'x/lr0', 'x/lr1', 'x/lr2'] # params
if self.txt:
with open(self.results_file, 'a') as f:
f.write(s + '%10.4g' * 7 % results + '\n') # append metrics, val_loss
x = {k: v for k, v in zip(keys, vals)} # dict

if self.csv:
file = self.save_dir / 'results.csv'
n = len(x) + 1 # number of cols
s = '' if file.exists() else (('%20s,' * n % tuple(['epoch'] + keys)).rstrip(',') + '\n') # add header
with open(file, 'a') as f:
f.write(s + ('%20.5g,' * n % tuple([epoch] + vals)).rstrip(',') + '\n')

if self.tb:
for x, tag in zip(vals, tags):
self.tb.add_scalar(tag, x, epoch) # TensorBoard
for k, v in x.items():
self.tb.add_scalar(k, v, epoch) # TensorBoard

if self.wandb:
self.wandb.log({k: v for k, v in zip(tags, vals)})
self.wandb.log(x)
self.wandb.end_epoch(best_result=best_fitness == fi)

def on_model_save(self, last, epoch, final_epoch, best_fitness, fi):
Expand All @@ -112,8 +123,10 @@ def on_model_save(self, last, epoch, final_epoch, best_fitness, fi):
if ((epoch + 1) % self.opt.save_period == 0 and not final_epoch) and self.opt.save_period != -1:
self.wandb.log_model(last.parent, self.opt, epoch, fi, best_model=best_fitness == fi)

def on_train_end(self, last, best):
def on_train_end(self, last, best, plots):
# Callback runs on training end
if plots:
plot_results(dir=self.save_dir) # save results.png
files = ['results.png', 'confusion_matrix.png', *[f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R')]]
files = [(self.save_dir / f) for f in files if (self.save_dir / f).exists()] # filter
if self.wandb:
Expand Down
3 changes: 1 addition & 2 deletions utils/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,8 +162,7 @@ def __call__(self, p, targets): # predictions, targets, model
lcls *= self.hyp['cls']
bs = tobj.shape[0] # batch size

loss = lbox + lobj + lcls
return loss * bs, torch.cat((lbox, lobj, lcls, loss)).detach()
return (lbox + lobj + lcls) * bs, torch.cat((lbox, lobj, lcls)).detach()

def build_targets(self, p, targets):
# Build targets for compute_loss(), input targets(image,class,x,y,w,h)
Expand Down
68 changes: 16 additions & 52 deletions utils/plots.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
# Plotting utils

import glob
import os
from copy import copy
from pathlib import Path

Expand Down Expand Up @@ -387,63 +385,29 @@ def profile_idetection(start=0, stop=0, labels=(), save_dir=''):
plt.savefig(Path(save_dir) / 'idetection_profile.png', dpi=200)


def plot_results_overlay(start=0, stop=0): # from utils.plots import *; plot_results_overlay()
# Plot training 'results*.txt', overlaying train and val losses
s = ['train', 'train', 'train', 'Precision', 'mAP@0.5', 'val', 'val', 'val', 'Recall', 'mAP@0.5:0.95'] # legends
t = ['Box', 'Objectness', 'Classification', 'P-R', 'mAP-F1'] # titles
for f in sorted(glob.glob('results*.txt') + glob.glob('../../Downloads/results*.txt')):
results = np.loadtxt(f, usecols=[2, 3, 4, 8, 9, 12, 13, 14, 10, 11], ndmin=2).T
n = results.shape[1] # number of rows
x = range(start, min(stop, n) if stop else n)
fig, ax = plt.subplots(1, 5, figsize=(14, 3.5), tight_layout=True)
ax = ax.ravel()
for i in range(5):
for j in [i, i + 5]:
y = results[j, x]
ax[i].plot(x, y, marker='.', label=s[j])
# y_smooth = butter_lowpass_filtfilt(y)
# ax[i].plot(x, np.gradient(y_smooth), marker='.', label=s[j])

ax[i].set_title(t[i])
ax[i].legend()
ax[i].set_ylabel(f) if i == 0 else None # add filename
fig.savefig(f.replace('.txt', '.png'), dpi=200)


def plot_results(start=0, stop=0, bucket='', id=(), labels=(), save_dir=''):
# Plot training 'results*.txt'. from utils.plots import *; plot_results(save_dir='runs/train/exp')
def plot_results(file='', dir=''):
# Plot training results.csv. Usage: from utils.plots import *; plot_results('path/to/results.csv')
save_dir = Path(file).parent if file else Path(dir)
fig, ax = plt.subplots(2, 5, figsize=(12, 6), tight_layout=True)
ax = ax.ravel()
s = ['Box', 'Objectness', 'Classification', 'Precision', 'Recall',
'val Box', 'val Objectness', 'val Classification', 'mAP@0.5', 'mAP@0.5:0.95']
if bucket:
# files = ['https://storage.googleapis.com/%s/results%g.txt' % (bucket, x) for x in id]
files = ['results%g.txt' % x for x in id]
c = ('gsutil cp ' + '%s ' * len(files) + '.') % tuple('gs://%s/results%g.txt' % (bucket, x) for x in id)
os.system(c)
else:
files = list(Path(save_dir).glob('results*.txt'))
assert len(files), 'No results.txt files found in %s, nothing to plot.' % os.path.abspath(save_dir)
files = list(save_dir.glob('results*.csv'))
assert len(files), f'No results.csv files found in {save_dir.resolve()}, nothing to plot.'
for fi, f in enumerate(files):
try:
results = np.loadtxt(f, usecols=[2, 3, 4, 8, 9, 12, 13, 14, 10, 11], ndmin=2).T
n = results.shape[1] # number of rows
x = range(start, min(stop, n) if stop else n)
for i in range(10):
y = results[i, x]
if i in [0, 1, 2, 5, 6, 7]:
y[y == 0] = np.nan # don't show zero loss values
# y /= y[0] # normalize
label = labels[fi] if len(labels) else f.stem
ax[i].plot(x, y, marker='.', label=label, linewidth=2, markersize=8)
ax[i].set_title(s[i])
# if i in [5, 6, 7]: # share train and val loss y axes
data = pd.read_csv(f)
s = [x.strip() for x in data.columns]
x = data.values[:, 0]
for i, j in enumerate([1, 2, 3, 4, 5, 8, 9, 10, 6, 7]):
y = data.values[:, j]
# y[y == 0] = np.nan # don't show zero values
ax[i].plot(x, y, marker='.', label=f.stem, linewidth=2, markersize=8)
ax[i].set_title(s[j], fontsize=12)
# if j in [8, 9, 10]: # share train and val loss y axes
# ax[i].get_shared_y_axes().join(ax[i], ax[i - 5])
except Exception as e:
print('Warning: Plotting error for %s; %s' % (f, e))

print(f'Warning: Plotting error for {f}: {e}')
ax[1].legend()
fig.savefig(Path(save_dir) / 'results.png', dpi=200)
fig.savefig(save_dir / 'results.png', dpi=200)


def feature_visualization(x, module_type, stage, n=32, save_dir=Path('runs/detect/exp')):
Expand Down
2 changes: 1 addition & 1 deletion val.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def run(data,

# Compute loss
if compute_loss:
loss += compute_loss([x.float() for x in train_out], targets)[1][:3] # box, obj, cls
loss += compute_loss([x.float() for x in train_out], targets)[1] # box, obj, cls

# Run NMS
targets[:, 2:] *= torch.Tensor([width, height, width, height]).to(device) # to pixels
Expand Down