diff --git a/test.py b/test.py index 9f484c809052..91176eca01db 100644 --- a/test.py +++ b/test.py @@ -269,6 +269,7 @@ def test(data, print(f'pycocotools unable to run: {e}') # Return results + model.float() # for training if not training: s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else '' print(f"Results saved to {save_dir}{s}") diff --git a/train.py b/train.py index 7aa57fa99e24..e37cf816bcb1 100644 --- a/train.py +++ b/train.py @@ -4,6 +4,7 @@ import os import random import time +from copy import deepcopy from pathlib import Path from threading import Thread @@ -381,8 +382,8 @@ def train(hyp, opt, device, tb_writer=None, wandb=None): ckpt = {'epoch': epoch, 'best_fitness': best_fitness, 'training_results': results_file.read_text(), - 'model': (model.module if is_parallel(model) else model).half(), - 'ema': (ema.ema.half(), ema.updates), + 'model': deepcopy(model.module if is_parallel(model) else model).half(), + 'ema': (deepcopy(ema.ema).half(), ema.updates), 'optimizer': optimizer.state_dict(), 'wandb_id': wandb_run.id if wandb else None} @@ -392,8 +393,6 @@ def train(hyp, opt, device, tb_writer=None, wandb=None): torch.save(ckpt, best) del ckpt - model.float(), ema.ema.float() - # end epoch ---------------------------------------------------------------------------------------------------- # end training