Skip to content

Commit

Permalink
Update train.py (ultralytics#2290)
Browse files Browse the repository at this point in the history
* Update train.py

* Update train.py

* Update train.py

* Update train.py

* Create train.py
  • Loading branch information
glenn-jocher authored Feb 24, 2021
1 parent 1b766b2 commit 2b6949b
Showing 1 changed file with 16 additions and 19 deletions.
35 changes: 16 additions & 19 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,8 +146,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):

# Results
if ckpt.get('training_results') is not None:
with open(results_file, 'w') as file:
file.write(ckpt['training_results']) # write results.txt
results_file.write_text(ckpt['training_results']) # write results.txt

# Epochs
start_epoch = ckpt['epoch'] + 1
Expand Down Expand Up @@ -354,7 +353,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):

# Write
with open(results_file, 'a') as f:
f.write(s + '%10.4g' * 7 % results + '\n') # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls)
f.write(s + '%10.4g' * 7 % results + '\n') # append metrics, val_loss
if len(opt.name) and opt.bucket:
os.system('gsutil cp %s gs://%s/results/results%s.txt' % (results_file, opt.bucket, opt.name))

Expand All @@ -375,15 +374,13 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
best_fitness = fi

# Save model
save = (not opt.nosave) or (final_epoch and not opt.evolve)
if save:
with open(results_file, 'r') as f: # create checkpoint
ckpt = {'epoch': epoch,
'best_fitness': best_fitness,
'training_results': f.read(),
'model': ema.ema,
'optimizer': None if final_epoch else optimizer.state_dict(),
'wandb_id': wandb_run.id if wandb else None}
if (not opt.nosave) or (final_epoch and not opt.evolve): # if save
ckpt = {'epoch': epoch,
'best_fitness': best_fitness,
'training_results': results_file.read_text(),
'model': ema.ema,
'optimizer': None if final_epoch else optimizer.state_dict(),
'wandb_id': wandb_run.id if wandb else None}

# Save last, best and delete
torch.save(ckpt, last)
Expand All @@ -396,9 +393,9 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
if rank in [-1, 0]:
# Strip optimizers
final = best if best.exists() else last # final model
for f in [last, best]:
for f in last, best:
if f.exists():
strip_optimizer(f) # strip optimizers
strip_optimizer(f)
if opt.bucket:
os.system(f'gsutil cp {final} gs://{opt.bucket}/weights') # upload

Expand All @@ -415,17 +412,17 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
# Test best.pt
logger.info('%g epochs completed in %.3f hours.\n' % (epoch - start_epoch + 1, (time.time() - t0) / 3600))
if opt.data.endswith('coco.yaml') and nc == 80: # if COCO
for conf, iou, save_json in ([0.25, 0.45, False], [0.001, 0.65, True]): # speed, mAP tests
for m in (last, best) if best.exists() else (last): # speed, mAP tests
results, _, _ = test.test(opt.data,
batch_size=batch_size * 2,
imgsz=imgsz_test,
conf_thres=conf,
iou_thres=iou,
model=attempt_load(final, device).half(),
conf_thres=0.001,
iou_thres=0.7,
model=attempt_load(m, device).half(),
single_cls=opt.single_cls,
dataloader=testloader,
save_dir=save_dir,
save_json=save_json,
save_json=True,
plots=False)

else:
Expand Down

0 comments on commit 2b6949b

Please sign in to comment.