From 8d7b2d12d0c09a8a8bbdb99b5c9a20bb8bdbb55b Mon Sep 17 00:00:00 2001
From: Glenn Jocher
Date: Wed, 24 Feb 2021 21:03:21 -0800
Subject: [PATCH] Improved model+EMA checkpointing 2 (#2295)

---
 test.py  | 1 +
 train.py | 7 +++----
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/test.py b/test.py
index 9f484c809052..91176eca01db 100644
--- a/test.py
+++ b/test.py
@@ -269,6 +269,7 @@ def test(data,
             print(f'pycocotools unable to run: {e}')
 
     # Return results
+    model.float()  # for training
     if not training:
         s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else ''
         print(f"Results saved to {save_dir}{s}")
diff --git a/train.py b/train.py
index 7aa57fa99e24..e37cf816bcb1 100644
--- a/train.py
+++ b/train.py
@@ -4,6 +4,7 @@
 import os
 import random
 import time
+from copy import deepcopy
 from pathlib import Path
 from threading import Thread
 
@@ -381,8 +382,8 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
             ckpt = {'epoch': epoch,
                     'best_fitness': best_fitness,
                     'training_results': results_file.read_text(),
-                    'model': (model.module if is_parallel(model) else model).half(),
-                    'ema': (ema.ema.half(), ema.updates),
+                    'model': deepcopy(model.module if is_parallel(model) else model).half(),
+                    'ema': (deepcopy(ema.ema).half(), ema.updates),
                     'optimizer': optimizer.state_dict(),
                     'wandb_id': wandb_run.id if wandb else None}
 
@@ -392,8 +393,6 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
                 torch.save(ckpt, best)
             del ckpt
 
-            model.float(), ema.ema.float()
-
         # end epoch ----------------------------------------------------------------------------------------------------
     # end training
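
Note on the pattern: nn.Module.half() casts parameters in place and returns self, so the old code converted the live training model and EMA to FP16 while building the checkpoint dict, then had to undo the cast with `model.float(), ema.ema.float()` after saving. deepcopy() snapshots the module first, so the FP16 cast applies only to the saved copy and the post-save restore in train.py can be dropped. The remaining `model.float()` added to test.py restores FP32 after validation, since test() may half-cast the model it is handed (the EMA model, when called from the training loop) for CUDA inference.

A minimal runnable sketch of the deepcopy-then-half pattern follows; Net and save_checkpoint are hypothetical names used for illustration, not part of this patch:

    from copy import deepcopy

    import torch
    import torch.nn as nn


    class Net(nn.Module):  # hypothetical stand-in for the real model
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(4, 2)

        def forward(self, x):
            return self.fc(x)


    def save_checkpoint(model, path):
        # deepcopy first, then .half(): the FP16 cast mutates only the copy
        ckpt = {'model': deepcopy(model).half()}
        torch.save(ckpt, path)


    model = Net()
    save_checkpoint(model, 'last.pt')
    # The live model keeps training in FP32; no model.float() restore needed.
    assert next(model.parameters()).dtype == torch.float32

Both the old and new code store FP16 weights on disk (halving checkpoint size); the change only affects whether the in-memory training model is mutated during saving.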