Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Evolve in CSV format #4307

Merged
merged 12 commits into from
Aug 4, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .dockerignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ coco
storage.googleapis.com

data/samples/*
**/results*.txt
**/results*.csv
*.jpg

# Neural Network weights -----------------------------------------------------------------------------------------------
Expand Down
1 change: 0 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ data/*
!data/images/bus.jpg
!data/*.sh

results*.txt
results*.csv

# Datasets -------------------------------------------------------------------------------------------------------------
Expand Down
32 changes: 18 additions & 14 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
check_requirements, print_mutation, set_logging, one_cycle, colorstr, methods
from utils.downloads import attempt_download
from utils.loss import ComputeLoss
from utils.plots import plot_labels, plot_evolution
from utils.plots import plot_labels, plot_evolve
from utils.torch_utils import ModelEMA, select_device, intersect_dicts, torch_distributed_zero_first, de_parallel
from utils.loggers.wandb.wandb_utils import check_wandb_resume
from utils.metrics import fitness
Expand Down Expand Up @@ -367,7 +367,8 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
fi = fitness(np.array(results).reshape(1, -1)) # weighted combination of [P, R, mAP@.5, mAP@.5-.95]
if fi > best_fitness:
best_fitness = fi
callbacks.on_fit_epoch_end(mloss, results, lr, epoch, best_fitness, fi)
log_vals = list(mloss) + list(results) + lr
callbacks.on_fit_epoch_end(log_vals, epoch, best_fitness, fi)

# Save model
if (not nosave) or (final_epoch and not evolve): # if save
Expand Down Expand Up @@ -464,7 +465,7 @@ def main(opt):
check_requirements(requirements=FILE.parent / 'requirements.txt', exclude=['thop'])

# Resume
if opt.resume and not check_wandb_resume(opt): # resume an interrupted run
if opt.resume and not check_wandb_resume(opt) and not opt.evolve: # resume an interrupted run
ckpt = opt.resume if isinstance(opt.resume, str) else get_latest_run() # specified or most recent path
assert os.path.isfile(ckpt), 'ERROR: --resume checkpoint does not exist'
with open(Path(ckpt).parent.parent / 'opt.yaml') as f:
Expand All @@ -474,8 +475,10 @@ def main(opt):
else:
opt.data, opt.cfg, opt.hyp = check_file(opt.data), check_file(opt.cfg), check_file(opt.hyp) # check files
assert len(opt.cfg) or len(opt.weights), 'either --cfg or --weights must be specified'
opt.name = 'evolve' if opt.evolve else opt.name
opt.save_dir = str(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok or opt.evolve))
if opt.evolve:
opt.project = 'runs/evolve'
opt.exist_ok = opt.resume
opt.save_dir = str(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok))

# DDP mode
device = select_device(opt.device, batch_size=opt.batch_size)
Expand Down Expand Up @@ -533,17 +536,17 @@ def main(opt):
hyp = yaml.safe_load(f) # load hyps dict
if 'anchors' not in hyp: # anchors commented in hyp.yaml
hyp['anchors'] = 3
opt.noval, opt.nosave = True, True # only val/save final epoch
opt.noval, opt.nosave, save_dir = True, True, Path(opt.save_dir) # only val/save final epoch
# ei = [isinstance(x, (int, float)) for x in hyp.values()] # evolvable indices
yaml_file = Path(opt.save_dir) / 'hyp_evolved.yaml' # save best result here
evolve_yaml, evolve_csv = save_dir / 'hyp_evolve.yaml', save_dir / 'evolve.csv'
if opt.bucket:
os.system(f'gsutil cp gs://{opt.bucket}/evolve.txt .') # download evolve.txt if exists
os.system(f'gsutil cp gs://{opt.bucket}/evolve.csv {save_dir}') # download evolve.csv if exists

for _ in range(opt.evolve): # generations to evolve
if Path('evolve.txt').exists(): # if evolve.txt exists: select best hyps and mutate
if evolve_csv.exists(): # if evolve.csv exists: select best hyps and mutate
# Select parent(s)
parent = 'single' # parent selection method: 'single' or 'weighted'
x = np.loadtxt('evolve.txt', ndmin=2)
x = np.loadtxt(evolve_csv, ndmin=2, delimiter=',', skiprows=1)
n = min(5, len(x)) # number of previous results to consider
x = x[np.argsort(-fitness(x))][:n] # top n mutations
w = fitness(x) - fitness(x).min() + 1E-6 # weights (sum > 0)
Expand Down Expand Up @@ -575,12 +578,13 @@ def main(opt):
results = train(hyp.copy(), opt, device)

# Write mutation results
print_mutation(hyp.copy(), results, yaml_file, opt.bucket)
print_mutation(results, hyp.copy(), save_dir, opt.bucket)

# Plot results
plot_evolution(yaml_file)
print(f'Hyperparameter evolution complete. Best results saved as: {yaml_file}\n'
f'Command to train a new model with these hyperparameters: $ python train.py --hyp {yaml_file}')
plot_evolve(evolve_csv)
print(f'Hyperparameter evolution finished\n'
f"Results saved to {colorstr('bold', save_dir)}"
f'Use best hyperparameters example: $ python train.py --hyp {evolve_yaml}')


def run(**kwargs):
Expand Down
50 changes: 29 additions & 21 deletions utils/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -615,35 +615,43 @@ def strip_optimizer(f='best.pt', s=''): # from utils.general import *; strip_op
print(f"Optimizer stripped from {f},{(' saved as %s,' % s) if s else ''} {mb:.1f}MB")


def print_mutation(hyp, results, yaml_file='hyp_evolved.yaml', bucket=''):
# Print mutation results to evolve.txt (for use with train.py --evolve)
a = '%10s' * len(hyp) % tuple(hyp.keys()) # hyperparam keys
b = '%10.3g' * len(hyp) % tuple(hyp.values()) # hyperparam values
c = '%10.4g' * len(results) % results # results (P, R, mAP@0.5, mAP@0.5:0.95, val_losses x 3)
print('\n%s\n%s\nEvolved fitness: %s\n' % (a, b, c))
def print_mutation(results, hyp, save_dir, bucket):
evolve_csv, results_csv, evolve_yaml = save_dir / 'evolve.csv', save_dir / 'results.csv', save_dir / 'hyp_evolve.yaml'
keys = ('metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95',
'val/box_loss', 'val/obj_loss', 'val/cls_loss') + tuple(hyp.keys()) # [results + hyps]
keys = tuple(x.strip() for x in keys)
vals = results + tuple(hyp.values())
n = len(keys)

# Download (optional)
if bucket:
url = 'gs://%s/evolve.txt' % bucket
if gsutil_getsize(url) > (os.path.getsize('evolve.txt') if os.path.exists('evolve.txt') else 0):
os.system('gsutil cp %s .' % url) # download evolve.txt if larger than local
url = f'gs://{bucket}/evolve.csv'
if gsutil_getsize(url) > (os.path.getsize(evolve_csv) if os.path.exists(evolve_csv) else 0):
os.system(f'gsutil cp {url} {save_dir}') # download evolve.csv if larger than local

# Log to evolve.csv
s = '' if evolve_csv.exists() else (('%20s,' * n % keys).rstrip(',') + '\n') # add header
with open(evolve_csv, 'a') as f:
f.write(s + ('%20.5g,' * n % vals).rstrip(',') + '\n')

with open('evolve.txt', 'a') as f: # append result
f.write(c + b + '\n')
x = np.unique(np.loadtxt('evolve.txt', ndmin=2), axis=0) # load unique rows
x = x[np.argsort(-fitness(x))] # sort
np.savetxt('evolve.txt', x, '%10.3g') # save sort by fitness
# Print to screen
print(colorstr('evolve: ') + ', '.join(f'{x.strip():>20s}' for x in keys))
print(colorstr('evolve: ') + ', '.join(f'{x:20.5g}' for x in vals), end='\n\n\n')

# Save yaml
for i, k in enumerate(hyp.keys()):
hyp[k] = float(x[0, i + 7])
with open(yaml_file, 'w') as f:
results = tuple(x[0, :7])
c = '%10.4g' * len(results) % results # results (P, R, mAP@0.5, mAP@0.5:0.95, val_losses x 3)
f.write('# Hyperparameter Evolution Results\n# Generations: %g\n# Metrics: ' % len(x) + c + '\n\n')
with open(evolve_yaml, 'w') as f:
data = pd.read_csv(evolve_csv)
data = data.rename(columns=lambda x: x.strip()) # strip keys
i = np.argmax(fitness(data.values[:, :7])) #
f.write(f'# YOLOv5 Hyperparameter Evolution Results\n' +
f'# Best generation: {i}\n' +
f'# Last generation: {len(data)}\n' +
f'# ' + ', '.join(f'{x.strip():>20s}' for x in keys[:7]) + '\n' +
f'# ' + ', '.join(f'{x:>20.5g}' for x in data.values[i, :7]) + '\n\n')
yaml.safe_dump(hyp, f, sort_keys=False)

if bucket:
os.system('gsutil cp evolve.txt %s gs://%s' % (yaml_file, bucket)) # upload
os.system(f'gsutil cp {evolve_csv} {evolve_yaml} gs://{bucket}') # upload


def apply_classifier(x, model, img, im0):
Expand Down
5 changes: 2 additions & 3 deletions utils/loggers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,8 @@ def on_val_end(self):
files = sorted(self.save_dir.glob('val*.jpg'))
self.wandb.log({"Validation": [wandb.Image(str(f), caption=f.name) for f in files]})

def on_fit_epoch_end(self, mloss, results, lr, epoch, best_fitness, fi):
def on_fit_epoch_end(self, vals, epoch, best_fitness, fi):
# Callback runs at the end of each fit (train+val) epoch
vals = list(mloss) + list(results) + lr
x = {k: v for k, v in zip(self.keys, vals)} # dict
if self.csv:
file = self.save_dir / 'results.csv'
Expand All @@ -123,7 +122,7 @@ def on_model_save(self, last, epoch, final_epoch, best_fitness, fi):
def on_train_end(self, last, best, plots, epoch):
# Callback runs on training end
if plots:
plot_results(dir=self.save_dir) # save results.png
plot_results(file=self.save_dir / 'results.csv') # save results.png
files = ['results.png', 'confusion_matrix.png', *[f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R')]]
files = [(self.save_dir / f) for f in files if (self.save_dir / f).exists()] # filter

Expand Down
50 changes: 25 additions & 25 deletions utils/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,30 +325,6 @@ def plot_labels(labels, names=(), save_dir=Path('')):
plt.close()


def plot_evolution(yaml_file='data/hyp.finetune.yaml'):  # from utils.plots import *; plot_evolution()
    # Plot the hyperparameter evolution results stored in ./evolve.txt.
    # The YAML file supplies the hyperparameter names (and their order); evolve.txt supplies one row per
    # result, where column i + 7 is read as the value of the i-th hyperparameter.
    # Draws one fitness-vs-value scatter per hyperparameter, marks the best row with a '+', prints the best
    # value of each hyperparameter and writes the figure to ./evolve.png.
    with open(yaml_file) as stream:
        hyp = yaml.safe_load(stream)
    results = np.loadtxt('evolve.txt', ndmin=2)
    scores = fitness(results)
    # weights = (scores - scores.min()) ** 2  # for weighted results

    plt.figure(figsize=(10, 12), tight_layout=True)
    matplotlib.rc('font', **{'size': 8})
    for idx, name in enumerate(hyp):  # only the keys are needed, the YAML values are unused
        column = results[:, idx + 7]
        # best = (column * weights).sum() / weights.sum()  # best weighted result
        best = column[scores.argmax()]  # best single result
        plt.subplot(6, 5, idx + 1)
        plt.scatter(column, scores, c=hist2d(column, scores, 20), cmap='viridis', alpha=.8, edgecolors='none')
        plt.plot(best, scores.max(), 'k+', markersize=15)
        plt.title('%s = %.3g' % (name, best), fontdict={'size': 9})  # limit to 40 characters
        if idx % 5:  # y tick labels only on the first subplot of each row
            plt.yticks([])
        print('%15s: %.3g' % (name, best))

    plt.savefig('evolve.png', dpi=200)
    print('\nPlot saved as evolve.png')


def profile_idetection(start=0, stop=0, labels=(), save_dir=''):
# Plot iDetection '*.txt' per-image logs. from utils.plots import *; profile_idetection()
ax = plt.subplots(2, 4, figsize=(12, 6), tight_layout=True)[1].ravel()
Expand Down Expand Up @@ -381,7 +357,31 @@ def profile_idetection(start=0, stop=0, labels=(), save_dir=''):
plt.savefig(Path(save_dir) / 'idetection_profile.png', dpi=200)


def plot_results(file='', dir=''):
def plot_evolve(evolve_csv=Path('path/to/evolve.csv')):  # from utils.plots import *; plot_evolve()
    # Plot evolve.csv hyp evolution results: one fitness-vs-value scatter per hyperparameter column.
    # Arguments:
    #   evolve_csv: path (str or Path) of the CSV to read. Columns 0-6 are not plotted; every remaining
    #               column gets one subplot titled with its (whitespace-stripped) header name.
    # Returns: None.
    # Side effects: prints the best value of each plotted column, saves the figure next to the CSV with a
    #               '.png' suffix, and sets the global matplotlib font size to 8.
    evolve_csv = Path(evolve_csv)  # accept plain str paths too; .with_suffix() below requires a Path
    data = pd.read_csv(evolve_csv)
    keys = [x.strip() for x in data.columns]
    x = data.values
    f = fitness(x)  # presumably one fitness score per CSV row - it is indexed and plotted per row below
    j = np.argmax(f)  # max fitness index
    hyp_keys = keys[7:]
    rows = max(6, -(-len(hyp_keys) // 5))  # keep the 6x5 grid; grow only when there are more than 30 columns
    fig = plt.figure(figsize=(10, 12), tight_layout=True)
    matplotlib.rc('font', **{'size': 8})
    for i, k in enumerate(hyp_keys):
        v = x[:, 7 + i]
        mu = v[j]  # best single result
        plt.subplot(rows, 5, i + 1)
        plt.scatter(v, f, c=hist2d(v, f, 20), cmap='viridis', alpha=.8, edgecolors='none')
        plt.plot(mu, f.max(), 'k+', markersize=15)
        plt.title('%s = %.3g' % (k, mu), fontdict={'size': 9})  # limit to 40 characters
        if i % 5 != 0:  # y tick labels only on the first subplot of each row
            plt.yticks([])
        print('%15s: %.3g' % (k, mu))
    out = evolve_csv.with_suffix('.png')  # separate name so the fitness array `f` is not shadowed
    plt.savefig(out, dpi=200)
    plt.close(fig)  # release the figure so repeated calls do not accumulate open figures
    print(f'Saved {out}')


def plot_results(file='path/to/results.csv', dir=''):
# Plot training results.csv. Usage: from utils.plots import *; plot_results('path/to/results.csv')
save_dir = Path(file).parent if file else Path(dir)
fig, ax = plt.subplots(2, 5, figsize=(12, 6), tight_layout=True)
Expand Down