diff --git a/train.py b/train.py index bbf879f3af5f..5c203f12651d 100644 --- a/train.py +++ b/train.py @@ -383,7 +383,8 @@ def train(hyp, opt, device, tb_writer=None, wandb=None): ckpt = {'epoch': epoch, 'best_fitness': best_fitness, 'training_results': results_file.read_text(), - 'model': deepcopy(model.module if is_parallel(model) else model).half(), + 'model': ema.ema if final_epoch else deepcopy( + model.module if is_parallel(model) else model).half(), 'ema': (deepcopy(ema.ema).half(), ema.updates), 'optimizer': optimizer.state_dict(), 'wandb_id': wandb_run.id if wandb else None}