Skip to content

Commit

Permalink
Fix segment evolution keys (#9742)
Browse files Browse the repository at this point in the history
* Update

* Cleanup
  • Loading branch information
glenn-jocher committed Oct 9, 2022
1 parent 5ef69ef commit 209be93
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 5 deletions.
2 changes: 1 addition & 1 deletion segment/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -651,7 +651,7 @@ def main(opt, callbacks=Callbacks()):
results = train(hyp.copy(), opt, device, callbacks)
callbacks = Callbacks()
# Write mutation results
print_mutation(results, hyp.copy(), save_dir, opt.bucket)
print_mutation(KEYS, results, hyp.copy(), save_dir, opt.bucket)

# Plot results
plot_evolve(evolve_csv)
Expand Down
4 changes: 3 additions & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -607,7 +607,9 @@ def main(opt, callbacks=Callbacks()):
results = train(hyp.copy(), opt, device, callbacks)
callbacks = Callbacks()
# Write mutation results
print_mutation(results, hyp.copy(), save_dir, opt.bucket)
keys = ('metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95', 'val/box_loss',
'val/obj_loss', 'val/cls_loss')
print_mutation(keys, results, hyp.copy(), save_dir, opt.bucket)

# Plot results
plot_evolve(evolve_csv)
Expand Down
5 changes: 2 additions & 3 deletions utils/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -957,11 +957,10 @@ def strip_optimizer(f='best.pt', s=''): # from utils.general import *; strip_op
LOGGER.info(f"Optimizer stripped from {f},{f' saved as {s},' if s else ''} {mb:.1f}MB")


def print_mutation(results, hyp, save_dir, bucket, prefix=colorstr('evolve: ')):
def print_mutation(keys, results, hyp, save_dir, bucket, prefix=colorstr('evolve: ')):
evolve_csv = save_dir / 'evolve.csv'
evolve_yaml = save_dir / 'hyp_evolve.yaml'
keys = ('metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95', 'val/box_loss',
'val/obj_loss', 'val/cls_loss') + tuple(hyp.keys()) # [results + hyps]
keys = tuple(keys) + tuple(hyp.keys()) # [results + hyps]
keys = tuple(x.strip() for x in keys)
vals = results + tuple(hyp.values())
n = len(keys)
Expand Down

0 comments on commit 209be93

Please sign in to comment.