Skip to content

Commit

Permalink
Add device allocation for loss compute
Browse files: browse the repository at this point in the history
  • Loading branch information
yizhi.chen committed Jul 14, 2020
1 parent 4f08c69 commit c9558a9
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 9 deletions.
6 changes: 3 additions & 3 deletions train.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -80,7 +80,7 @@ def train(hyp, tb_writer, opt, device):
nc = 1 if opt.single_cls else int(data_dict['nc']) # number of classes

# Remove previous results
if opt.local_rank in [-1, 0]:
if local_rank in [-1, 0]:
for f in glob.glob('*_batch*.jpg') + glob.glob(results_file):
os.remove(f)

Expand Down Expand Up @@ -161,7 +161,6 @@ def train(hyp, tb_writer, opt, device):
if mixed_precision:
model, optimizer = amp.initialize(model, optimizer, opt_level='O1', verbosity=0)


# Scheduler https://arxiv.org/pdf/1812.01187.pdf
lf = lambda x: (((1 + math.cos(x * math.pi / epochs)) / 2) ** 1.0) * 0.9 + 0.1 # cosine
scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
Expand Down Expand Up @@ -405,7 +404,6 @@ def train(hyp, tb_writer, opt, device):


if __name__ == '__main__':
check_git_status()
parser = argparse.ArgumentParser()
parser.add_argument('--epochs', type=int, default=300)
parser.add_argument('--batch-size', type=int, default=16, help="batch size for all gpus.")
Expand All @@ -430,6 +428,8 @@ def train(hyp, tb_writer, opt, device):
parser.add_argument('--local_rank', type=int, default=-1, help="Extra parameter for DDP implementation. Don't use it manually.")
opt = parser.parse_args()
opt.weights = last if opt.resume and not opt.weights else opt.weights
with torch_distributed_zero_first(opt.local_rank):
check_git_status()
opt.cfg = check_file(opt.cfg) # check file
opt.data = check_file(opt.data) # check file
opt.img_size.extend([opt.img_size[-1]] * (2 - len(opt.img_size))) # extend to 2 sizes (train, test)
Expand Down
13 changes: 7 additions & 6 deletions utils/utils.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -432,15 +432,16 @@ def forward(self, pred, true):


def compute_loss(p, targets, model): # predictions, targets, model
device = targets.device
ft = torch.cuda.FloatTensor if p[0].is_cuda else torch.Tensor
lcls, lbox, lobj = ft([0]), ft([0]), ft([0])
lcls, lbox, lobj = ft([0]).to(device), ft([0]).to(device), ft([0]).to(device)
tcls, tbox, indices, anchors = build_targets(p, targets, model) # targets
h = model.hyp # hyperparameters
red = 'mean' # Loss reduction (sum or mean)

# Define criteria
BCEcls = nn.BCEWithLogitsLoss(pos_weight=ft([h['cls_pw']]), reduction=red)
BCEobj = nn.BCEWithLogitsLoss(pos_weight=ft([h['obj_pw']]), reduction=red)
BCEcls = nn.BCEWithLogitsLoss(pos_weight=ft([h['cls_pw']]), reduction=red).to(device)
BCEobj = nn.BCEWithLogitsLoss(pos_weight=ft([h['obj_pw']]), reduction=red).to(device)

# class label smoothing https://arxiv.org/pdf/1902.04103.pdf eqn 3
cp, cn = smooth_BCE(eps=0.0)
Expand All @@ -456,7 +457,7 @@ def compute_loss(p, targets, model): # predictions, targets, model
balance = [1.0, 1.0, 1.0]
for i, pi in enumerate(p): # layer index, layer predictions
b, a, gj, gi = indices[i] # image, anchor, gridy, gridx
tobj = torch.zeros_like(pi[..., 0]) # target obj
tobj = torch.zeros_like(pi[..., 0]).to(device) # target obj

nb = b.shape[0] # number of targets
if nb:
Expand All @@ -466,7 +467,7 @@ def compute_loss(p, targets, model): # predictions, targets, model
# GIoU
pxy = ps[:, :2].sigmoid() * 2. - 0.5
pwh = (ps[:, 2:4].sigmoid() * 2) ** 2 * anchors[i]
pbox = torch.cat((pxy, pwh), 1) # predicted box
pbox = torch.cat((pxy, pwh), 1).to(device) # predicted box
giou = bbox_iou(pbox.t(), tbox[i], x1y1x2y2=False, GIoU=True) # giou(prediction, target)
lbox += (1.0 - giou).sum() if red == 'sum' else (1.0 - giou).mean() # giou loss

Expand All @@ -475,7 +476,7 @@ def compute_loss(p, targets, model): # predictions, targets, model

# Class
if model.nc > 1: # cls loss (only if multiple classes)
t = torch.full_like(ps[:, 5:], cn) # targets
t = torch.full_like(ps[:, 5:], cn).to(device) # targets
t[range(nb), tcls[i]] = cp
lcls += BCEcls(ps[:, 5:], t) # BCE

Expand Down

0 comments on commit c9558a9

Please sign in to comment.