From 24380e6b43638c9d803d71ef42b7a972f0d94706 Mon Sep 17 00:00:00 2001
From: bilzard <36561962+bilzard@users.noreply.github.com>
Date: Sat, 5 Feb 2022 02:58:18 +0900
Subject: [PATCH] Load checkpoint on CPU instead of on GPU (#6516)

* Load checkpoint on CPU instead of on GPU

* refactor: simplify code

* Cleanup

* Update train.py

Co-authored-by: Glenn Jocher
---
 train.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/train.py b/train.py
index 2a973fb7164b..56103b8d4202 100644
--- a/train.py
+++ b/train.py
@@ -120,7 +120,7 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
     if pretrained:
         with torch_distributed_zero_first(LOCAL_RANK):
             weights = attempt_download(weights)  # download if not found locally
-        ckpt = torch.load(weights, map_location=device)  # load checkpoint
+        ckpt = torch.load(weights, map_location='cpu')  # load checkpoint to CPU to avoid CUDA memory leak
         model = Model(cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device)  # create
         exclude = ['anchor'] if (cfg or hyp.get('anchors')) and not resume else []  # exclude keys
         csd = ckpt['model'].float().state_dict()  # checkpoint state_dict as FP32