diff --git a/train.py b/train.py
index 5cbf6809d1d9..e7d303a510f5 100644
--- a/train.py
+++ b/train.py
@@ -32,7 +32,7 @@
 from utils.google_utils import attempt_download
 from utils.loss import ComputeLoss
 from utils.plots import plot_images, plot_labels, plot_results, plot_evolution
-from utils.torch_utils import ModelEMA, select_device, intersect_dicts, torch_distributed_zero_first, is_parallel
+from utils.torch_utils import ModelEMA, select_device, intersect_dicts, torch_distributed_zero_first, de_parallel
 from utils.wandb_logging.wandb_utils import WandbLogger, check_wandb_resume
@@ -367,7 +367,7 @@ def train(hyp, opt, device, tb_writer=None):
                 f = save_dir / f'train_batch{ni}.jpg'  # filename
                 Thread(target=plot_images, args=(imgs, targets, paths, f), daemon=True).start()
                 if tb_writer:
-                    tb_writer.add_graph(torch.jit.trace(model, imgs, strict=False), [])  # add model graph
+                    tb_writer.add_graph(torch.jit.trace(de_parallel(model), imgs, strict=False), [])  # model graph
                     # tb_writer.add_image(f, result, dataformats='HWC', global_step=epoch)
             elif plots and ni == 10 and wandb_logger.wandb:
                 wandb_logger.log({"Mosaics": [wandb_logger.wandb.Image(str(x), caption=x.name) for x in
@@ -442,7 +442,7 @@ def train(hyp, opt, device, tb_writer=None):
             ckpt = {'epoch': epoch,
                     'best_fitness': best_fitness,
                     'training_results': results_file.read_text(),
-                    'model': deepcopy(model.module if is_parallel(model) else model).half(),
+                    'model': deepcopy(de_parallel(model)).half(),
                     'ema': deepcopy(ema.ema).half(),
                     'updates': ema.updates,
                     'optimizer': optimizer.state_dict(),
diff --git a/utils/torch_utils.py b/utils/torch_utils.py
index 5074fa95ae4b..aa54c3cf561e 100644
--- a/utils/torch_utils.py
+++ b/utils/torch_utils.py
@@ -134,9 +134,15 @@ def profile(x, ops, n=100, device=None):
 
 
 def is_parallel(model):
+    # Returns True if model is of type DP or DDP
     return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
 
 
+def de_parallel(model):
+    # De-parallelize a model: returns single-GPU model if model is of type DP or DDP
+    return model.module if is_parallel(model) else model
+
+
 def intersect_dicts(da, db, exclude=()):
     # Dictionary intersection of matching keys and shapes, omitting 'exclude' keys, using da values
     return {k: v for k, v in da.items() if k in db and not any(x in k for x in exclude) and v.shape == db[k].shape}
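
For context, a minimal sketch (not part of the diff) of how the new helper behaves. Here nn.Linear is just a hypothetical stand-in for a YOLOv5 model, and the import path assumes the repo layout shown above:

    # Illustrative usage sketch; nn.Linear stands in for a real model.
    import torch.nn as nn

    from utils.torch_utils import de_parallel, is_parallel

    model = nn.Linear(10, 2)            # bare single-GPU module
    dp = nn.DataParallel(model)         # DP wrapper; DDP is handled the same way

    assert is_parallel(dp)              # True: type is nn.parallel.DataParallel
    assert de_parallel(dp) is model     # unwraps .module from the DP/DDP wrapper
    assert de_parallel(model) is model  # no-op on a model that is not wrapped

This is why the ckpt change above is a pure refactor: 'model.module if is_parallel(model) else model' and 'de_parallel(model)' return the same object.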