From c9558a9b51547febb03d9c1ca42e2ef0fc15bb31 Mon Sep 17 00:00:00 2001
From: "yizhi.chen"
Date: Tue, 14 Jul 2020 13:51:34 +0800
Subject: [PATCH] Add device allocation for loss compute

---
 train.py       |  6 +++---
 utils/utils.py | 13 +++++++------
 2 files changed, 10 insertions(+), 9 deletions(-)

diff --git a/train.py b/train.py
index 7bceda101def..29399e275fd8 100644
--- a/train.py
+++ b/train.py
@@ -80,7 +80,7 @@ def train(hyp, tb_writer, opt, device):
     nc = 1 if opt.single_cls else int(data_dict['nc'])  # number of classes
 
     # Remove previous results
-    if opt.local_rank in [-1, 0]:
+    if local_rank in [-1, 0]:
         for f in glob.glob('*_batch*.jpg') + glob.glob(results_file):
             os.remove(f)
 
@@ -161,7 +161,6 @@ def train(hyp, tb_writer, opt, device):
     if mixed_precision:
         model, optimizer = amp.initialize(model, optimizer, opt_level='O1', verbosity=0)
 
-
     # Scheduler https://arxiv.org/pdf/1812.01187.pdf
     lf = lambda x: (((1 + math.cos(x * math.pi / epochs)) / 2) ** 1.0) * 0.9 + 0.1  # cosine
     scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
@@ -405,7 +404,6 @@ def train(hyp, tb_writer, opt, device):
 
 
 if __name__ == '__main__':
-    check_git_status()
     parser = argparse.ArgumentParser()
     parser.add_argument('--epochs', type=int, default=300)
     parser.add_argument('--batch-size', type=int, default=16, help="batch size for all gpus.")
@@ -430,6 +428,8 @@ def train(hyp, tb_writer, opt, device):
     parser.add_argument('--local_rank', type=int, default=-1, help="Extra parameter for DDP implementation. Don't use it manually.")
     opt = parser.parse_args()
     opt.weights = last if opt.resume and not opt.weights else opt.weights
+    with torch_distributed_zero_first(opt.local_rank):
+        check_git_status()
     opt.cfg = check_file(opt.cfg)  # check file
     opt.data = check_file(opt.data)  # check file
     opt.img_size.extend([opt.img_size[-1]] * (2 - len(opt.img_size)))  # extend to 2 sizes (train, test)
diff --git a/utils/utils.py b/utils/utils.py
index 4673fa5628e1..8fa044dba29d 100755
--- a/utils/utils.py
+++ b/utils/utils.py
@@ -432,15 +432,16 @@ def forward(self, pred, true):
 
 
 def compute_loss(p, targets, model):  # predictions, targets, model
+    device = targets.device
     ft = torch.cuda.FloatTensor if p[0].is_cuda else torch.Tensor
-    lcls, lbox, lobj = ft([0]), ft([0]), ft([0])
+    lcls, lbox, lobj = ft([0]).to(device), ft([0]).to(device), ft([0]).to(device)
     tcls, tbox, indices, anchors = build_targets(p, targets, model)  # targets
     h = model.hyp  # hyperparameters
     red = 'mean'  # Loss reduction (sum or mean)
 
     # Define criteria
-    BCEcls = nn.BCEWithLogitsLoss(pos_weight=ft([h['cls_pw']]), reduction=red)
-    BCEobj = nn.BCEWithLogitsLoss(pos_weight=ft([h['obj_pw']]), reduction=red)
+    BCEcls = nn.BCEWithLogitsLoss(pos_weight=ft([h['cls_pw']]), reduction=red).to(device)
+    BCEobj = nn.BCEWithLogitsLoss(pos_weight=ft([h['obj_pw']]), reduction=red).to(device)
 
     # class label smoothing https://arxiv.org/pdf/1902.04103.pdf eqn 3
     cp, cn = smooth_BCE(eps=0.0)
@@ -456,7 +457,7 @@ def compute_loss(p, targets, model):  # predictions, targets, model
     balance = [1.0, 1.0, 1.0]
     for i, pi in enumerate(p):  # layer index, layer predictions
         b, a, gj, gi = indices[i]  # image, anchor, gridy, gridx
-        tobj = torch.zeros_like(pi[..., 0])  # target obj
+        tobj = torch.zeros_like(pi[..., 0]).to(device)  # target obj
 
         nb = b.shape[0]  # number of targets
         if nb:
@@ -466,7 +467,7 @@ def compute_loss(p, targets, model):  # predictions, targets, model
             # GIoU
             pxy = ps[:, :2].sigmoid() * 2. - 0.5
             pwh = (ps[:, 2:4].sigmoid() * 2) ** 2 * anchors[i]
-            pbox = torch.cat((pxy, pwh), 1)  # predicted box
+            pbox = torch.cat((pxy, pwh), 1).to(device)  # predicted box
             giou = bbox_iou(pbox.t(), tbox[i], x1y1x2y2=False, GIoU=True)  # giou(prediction, target)
             lbox += (1.0 - giou).sum() if red == 'sum' else (1.0 - giou).mean()  # giou loss
 
@@ -475,7 +476,7 @@ def compute_loss(p, targets, model):  # predictions, targets, model
 
             # Class
             if model.nc > 1:  # cls loss (only if multiple classes)
-                t = torch.full_like(ps[:, 5:], cn)  # targets
+                t = torch.full_like(ps[:, 5:], cn).to(device)  # targets
                 t[range(nb), tcls[i]] = cp
                 lcls += BCEcls(ps[:, 5:], t)  # BCE