Model Loss #12

Closed
okanlv opened this issue Sep 12, 2018 · 16 comments
Assignees: glenn-jocher
Labels: bug (Something isn't working)

Comments

@okanlv

okanlv commented Sep 12, 2018

Hey,
Following Section 2.2 of YOLO, I have a few questions about the loss calculation shown at the end of this issue.

  1. We are using λ_coord = 5 on lines 156 to 159. Should we also use λ_noobj = 0.5 on line 167?

  2. Why are we multiplying BCELoss by 1.5 on line 160? I have not found any reference to this in the papers.

  3. pred_conf gives us a [batch_size x anchor_number x grid_size x grid_size] tensor. Assuming batch_size = 1, anchor_number = 3 and grid_size = 2, there are 12 elements in this tensor. If nM = 3, pred_conf[~mask] contains 9 elements, and so does mask[~mask].float(). BCEWithLogitsLoss1 gives the sum of the BCE loss over these 9 elements, whereas BCEWithLogitsLoss2 takes the mean of BCEWithLogitsLoss1 (i.e. divides it by 9 in our case). Now, my question is: why are we multiplying BCEWithLogitsLoss2 by nM instead of using BCEWithLogitsLoss1 (which should probably be divided by batch_size too) on line 167? There is no such division in Section 2.2 of YOLO. Btw, pred_conf[~mask] could normally contain 15k elements, so we are practically ignoring the confidence loss on line 167.

  4. Similar to 3, we should use BCEWithLogitsLoss1 (probably divided by batch_size too) on line 163, because
    BCEWithLogitsLoss1(pred_cls[mask], tcls.float()) / BCEWithLogitsLoss2(pred_cls[mask], tcls.float()) = batch_size x nM x number_of_classes.

  5. Why are we not dividing all the losses by the batch_size? As the batch_size increases, the loss increases too. However, we should minimize the expected loss per sample.

yolov3/models.py

Lines 155 to 167 in 9514e74

if nM > 0:
    lx = 5 * MSELoss(x[mask], tx[mask])
    ly = 5 * MSELoss(y[mask], ty[mask])
    lw = 5 * MSELoss(w[mask], tw[mask])
    lh = 5 * MSELoss(h[mask], th[mask])
    lconf = 1.5 * BCEWithLogitsLoss1(pred_conf[mask], mask[mask].float())
    # lcls = nM * CrossEntropyLoss(pred_cls[mask], torch.argmax(tcls, 1))
    lcls = nM * BCEWithLogitsLoss2(pred_cls[mask], tcls.float())
else:
    lx, ly, lw, lh, lcls, lconf = FT([0]), FT([0]), FT([0]), FT([0]), FT([0]), FT([0])
lconf += nM * BCEWithLogitsLoss2(pred_conf[~mask], mask[~mask].float())
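
To make the BCEWithLogitsLoss1 / BCEWithLogitsLoss2 distinction in question 3 concrete, here is a small standalone sketch (not the repo code) using the toy shapes above. BCEWithLogitsLoss1 and BCEWithLogitsLoss2 are assumed to be the summed and mean-reduced variants, written here with the modern reduction= argument (the 2018 code used size_average):

import torch
import torch.nn as nn

BCEWithLogitsLoss1 = nn.BCEWithLogitsLoss(reduction='sum')   # sums over elements
BCEWithLogitsLoss2 = nn.BCEWithLogitsLoss(reduction='mean')  # averages over elements

# batch_size = 1, anchor_number = 3, grid_size = 2 -> 12 elements, nM = 3 of them "on"
pred_conf = torch.randn(1, 3, 2, 2)
mask = torch.zeros(1, 3, 2, 2, dtype=torch.bool)
mask[0, 0, 0, 0] = mask[0, 1, 1, 1] = mask[0, 2, 0, 1] = True

loss_sum = BCEWithLogitsLoss1(pred_conf[~mask], mask[~mask].float())   # sum over the 9 background cells
loss_mean = BCEWithLogitsLoss2(pred_conf[~mask], mask[~mask].float())  # the same quantity divided by 9
print(torch.allclose(loss_sum, loss_mean * (~mask).sum()))             # True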

@glenn-jocher
Member

@okanlv 1 and 2: yes, you are right, there is a bug here. For some reason I was using lambda_obj=1.5 and lambda_noobj=1.0 rather than lambda_obj=1.0 and lambda_noobj=0.5. I've corrected this in commit 68de92f.

@glenn-jocher
Member

glenn-jocher commented Sep 13, 2018

@okanlv 3 has to do with balancing the objectness loss term, which is split into two parts (obj and noobj). noobj will always contain more samples than obj, so if left alone this imbalance would cause the training loss to always favor noobj (i.e. no objects would ever be detected).

The current version tries to balance them equally by averaging the noobj loss term and multiplying it by the target count nM. obj is not averaged, but since it consists of nM samples this should help the two terms attain equal weight. I'll try and follow your steps here.

If I start training and debug models.py, I see that the pred_conf shape [nB, nA, nG, nG], the target count nM, and the background count used in the noobj loss term are:

pred_conf.shape
Out[2]: torch.Size([12, 3, 13, 13])

nM
Out[3]: tensor(47.)

(~mask).sum()
Out[4]: tensor(6037)

If I apply the same BCEWithLogitsLoss1 term to both objectness lines 160 and 167, I naturally see imbalanced losses for the two terms, since the first has 47 samples and the second has 6037. This would cause loss_noobj to dominate the gradient, and loss_obj would be practically ignored for training purposes, which is bad.

loss_obj = BCEWithLogitsLoss1(pred_conf[mask], mask[mask].float())
Out[14]: tensor(32.41945, grad_fn=<SumBackward0>)

loss_noobj = 0.5 * BCEWithLogitsLoss1(pred_conf[~mask], mask[~mask].float())
Out[16]: tensor(2089.89233, grad_fn=<MulBackward0>)

If I instead use the current code for line 167, then you can see that the two loss terms are roughly equal in magnitude, so they both influence the gradient to similar degrees:

loss_noobj = 0.5 * nM * BCEWithLogitsLoss2(pred_conf[~mask], mask[~mask].float())
Out[17]: tensor(16.27049, grad_fn=<MulBackward1>)
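
The nM * mean trick above is just a rescaling of the summed loss. A minimal standalone check, using the counts from the debug session above and random stand-in tensors rather than the real batch:

import torch
import torch.nn as nn

nM, n_noobj = 47, 6037
pred_bg = torch.randn(n_noobj)     # stand-in logits for the background anchors
target_bg = torch.zeros(n_noobj)   # background target is 0 everywhere

bce_sum = nn.BCEWithLogitsLoss(reduction='sum')
bce_mean = nn.BCEWithLogitsLoss(reduction='mean')

# nM * mean(...) equals sum(...) scaled by nM / n_noobj, which pulls the noobj
# term down to the same order of magnitude as the nM-sample obj term.
lhs = nM * bce_mean(pred_bg, target_bg)
rhs = (nM / n_noobj) * bce_sum(pred_bg, target_bg)
print(torch.allclose(lhs, rhs))    # True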

@glenn-jocher
Member

glenn-jocher commented Sep 13, 2018

@okanlv question 4 is related to issue #3. For some reason nn.CrossEntropyLoss() seems to produce better results during training than nn.BCEWithLogitsLoss(). I don't know why this is the case, as YOLOv3 officially uses CE (though YOLOv2 did use BCE). I've observed similar improvements when training MNIST in PyTorch using CE vs BCE, so it may be a PyTorch issue.

Question 5 is valid though. I don't have a good answer. Are there some examples you could point to that show the loss term divided by the batch size?

@okanlv
Author

okanlv commented Sep 14, 2018

@glenn-jocher thanks, Q3 & Q4 make sense now.
Q5 -> I don't have a good reference in the original YOLO code, but this function returns the loss term. In addition, we should probably divide the loss term by the number of detected objects (nM) too, since it is undesirable to scale the loss by the number of detected objects.

--> loss_noobj = 0.5 * nM * BCEWithLogitsLoss2(pred_conf[~mask], mask[~mask].float()) gives 0 loss when nM = 0. We could use loss_noobj = 0.5 * BCEWithLogitsLoss2(pred_conf[~mask], mask[~mask].float()) instead when nM = 0.

@okanlv
Author

okanlv commented Sep 19, 2018

@glenn-jocher I think the cost is divided by the batch size at this line

@glenn-jocher
Member

glenn-jocher commented Sep 19, 2018

@okanlv I was just thinking about this. Currently the loss increases with the number of detected objects nM and with the batch_size. I think increasing with nM is ok, because it makes sense that an image with 10 objects would influence the gradient more than an image with 1 object, no?

But the batch_size itself I'm not sure about. If we divide by the batch_size, larger batches will take longer to train the model, i.e. batch_size=16 will train twice as slowly with the same size loss as batch_size=8. Currently the loss increases with batch_size, so I think training speed would be similar irrespective of batch_size.

But I see your link. n looks like the batch size there, and the code divides the loss by n. Do you think I should leave nM in place and divide by the batch size then?

@glenn-jocher
Member

Oh, by the way, nM would probably never be zero, because the loop in train.py currently skips batches that don't have any targets (though nM does get set 3 times per batch, once for each YOLO stride):

yolov3/train.py

Lines 116 to 119 in 1cfde4a

for i, (imgs, targets) in enumerate(dataloader):
    if sum([len(x) for x in targets]) < 1:  # if no targets continue
        continue

@okanlv
Author

okanlv commented Sep 20, 2018

> I was just thinking about this. Currently the loss increases with the number of detected objects nM and with the batch_size. I think increasing with nM is ok, because it makes sense that an image with 10 objects would influence the gradient more than an image with 1 object, no?

@glenn-jocher I am not sure about this. You might be right.

> But the batch_size itself I'm not sure about. If we divide by the batch_size, larger batches will take longer to train the model, i.e. batch_size=16 will train twice as slowly with the same size loss as batch_size=8. Currently the loss increases with batch_size, so I think training speed would be similar irrespective of batch_size.

Training speed should not depend on the batch size either, imo. You can adjust the learning rate to speed up training. Why do you think dividing by the batch_size affects the training speed? If we divide by batch_size, the loss will stay the same (on average) regardless of the batch_size. If you are talking about the speed change due to total_samples/batch_size, that is normal. For instance, with the current loss calculation, if the batch size is too large the loss might not converge.

> But I see your link. n looks like the batch size there, and the code divides the loss by n. Do you think I should leave nM in place and divide by the batch size then?

Actually n = subdivisions * ngpus and batch = batch_size. So you should divide the loss only by the batch size, as you suggested.

@glenn-jocher
Member

Ok, I've switched from Adam to SGD with burn-in (which exponentially ramps up the learning rate from 0 to 0.001 over the first 1000 iterations) in commit a722601. I included the loss division by the batch size as well:

loss = (lx + ly + lw + lh + lconf + lcls) / batch_size

@ydixon

ydixon commented Sep 21, 2018

@glenn-jocher Thanks for this excellent repo. I've been scratching my head trying to figure out how to balance the loss weights properly. Thanks for the explanation. I still have a few questions:

  1. Why are we doing mseloss on sigmoid(tx) instead of tx?
  2. How did you come up with the power method for w, h? Is it just by experimentation?
  3. How are the loss weight values (5, 0.5) obtained? Trial and error?
  4. Any idea why doing bceloss(size_average=True) on the confidence and cls losses wouldn't work? (I tried it and it didn't converge.)

@glenn-jocher
Member

glenn-jocher commented Sep 21, 2018

@ydixon don't thank me yet. The model works near-perfectly for inference, but there seem to remain issues with training. To answer your questions: the model attempts to replicate yolov3 with darknet, so almost all design choices come from there. Definitely read the paper:
https://pjreddie.com/media/files/papers/YOLOv3.pdf

  1. Why are we doing mseloss on sigmoid(tx) instead of tx?
    The box x and y values are defined in the paper in anchor space (i.e. 0-13 across the image width) as sigmoid(tx) + anchor location (0-13). This way an anchor at location 5, for example, can vary from 5-6 but not outside this range. The current losses are computed from these box locations in anchor space... are you saying we could try to compute the losses prior to sigmoiding (we would then need to logit the target values)?

  2. How did you come up with the power method for w, h? Is it just by experimentation?
    Sort of by experimenting, but I was constrained in that I needed output=1 from input=0 (i.e. f(0)=1) to match the existing output there. Both the exp and power methods for width and height do this well, but the power method keeps the output from exploding at the high end (see the sketch after this list). I only placed the power method there because SGD with burn-in would not converge using the normal method, hinting at a problem somewhere in the training code. The power method approaches 4 at the positive limit.

  3. How are the loss weight values (5, 0.5) obtained? Trial and error?
    No. These are specified in the YOLO paper: lambda_noobj = 0.5 and lambda_coord = 5.

  4. Any idea why doing bceloss(size_average=True) on the confidence and cls losses wouldn't work? (I tried it and it didn't converge.)
    The obj and noobj halves of the conf loss term need to be balanced, since noobj will always have many more samples than obj (most anchors will be empty). The current method does this by size-averaging noobj and multiplying it by nM, the number of samples in obj. This comment has a more detailed explanation: Model Loss #12 (comment). The alternative is to size-average both conf terms and then divide noobj by nM.
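
For reference, here is a small sketch of the two width/height mappings discussed in answer 2. The exact power-method expression isn't quoted in this thread, so (2 * sigmoid(t)) ** 2 below is an assumption chosen only to match the stated constraints (f(0) = 1, approaching 4 at the positive limit); the exp form is the standard YOLO one:

import torch

def decode_xy(tx, ty, grid_x, grid_y):
    # paper form: cell offset plus sigmoid of the raw prediction, so each box
    # center stays inside its own grid cell
    return torch.sigmoid(tx) + grid_x, torch.sigmoid(ty) + grid_y

def decode_wh_exp(tw, th, anchor_w, anchor_h):
    # original yolo form: equals the anchor at t = 0, unbounded above
    return anchor_w * torch.exp(tw), anchor_h * torch.exp(th)

def decode_wh_power(tw, th, anchor_w, anchor_h):
    # assumed power form: equals the anchor at t = 0, capped at 4x the anchor
    return anchor_w * (2 * torch.sigmoid(tw)) ** 2, anchor_h * (2 * torch.sigmoid(th)) ** 2

t0 = torch.zeros(1)
print(decode_wh_exp(t0, t0, torch.ones(1), torch.ones(1)))    # (tensor([1.]), tensor([1.]))
print(decode_wh_power(t0, t0, torch.ones(1), torch.ones(1)))  # (tensor([1.]), tensor([1.]))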

@okanlv okanlv closed this as completed Sep 22, 2018
@ydixon

ydixon commented Sep 22, 2018

@glenn-jocher

  1. That's what I picked up from the paper. In fact, I already tried it and the model still converges, but I haven't implemented metrics to see which method performs better.
  2. I need some more time to digest the paragraph. In my own implementation, the yolo method still converged.
  3. I think the lambda scales were removed when moving from YoloV2 to YoloV3. I'm not totally sure since I haven't fully read the C code yet.
  4. Here are a few things I experimented with:
mseloss(average=False)@x,y,w,h + bceloss(average=False)@conf,cls = converges
mseloss(average=False)@x,y,w,h + bceloss(average=True)@conf,cls = doesn't converge
bceloss(average=True)@x,y,w,h + bceloss(average=True)@conf,cls = converges but poor performance

And when I tried your new method (balancing by nM), it didn't converge as I hoped it would. It could be because we are creating masks in different ways. Let me know if I don't understand correctly. In your implementation, you set your mask by finding the best anchor and setting it to 1; that anchor box will learn objectness.

tconf[b, a, gj, gi] = 1

Then any other anchor boxes will learn no-objectness, judging by this line where you just flip the mask.
lconf += (0.5 * nM / nB) * BCEWithLogitsLoss2(pred_conf[~mask], mask[~mask].float())

I think that according to the paper, any anchor box that has IoU > 0.5 with a ground truth should not learn any objectness (noobj).

Also, if the goal is to balance the weights between obj and noobj, isn't that achieved by performing a size-average on both terms? It seems like this is more about balancing the confidence loss and the coordinate loss rather than the obj and noobj losses. Also, I think in the C version the model can still learn something when fed samples that have no labels. So having them balanced is not necessary?

I'm gonna keep testing. Please let me know your thoughts.

@glenn-jocher glenn-jocher reopened this Sep 23, 2018
@glenn-jocher
Member

glenn-jocher commented Sep 23, 2018

@ydixon

  1. I see. How are you turning the target values into logits when you compute losses on the pre-sigmoid network outputs? If part of the total loss is computed on logits and part on sigmoided values, wouldn't the logit losses be much higher than their sigmoid counterparts, since they are unbounded while the sigmoid losses can never extend past the -1 to 1 range for each sample?

  2. I actually don't want to use this; I was just forced into it to get SGD to converge, even with the 1000-iteration burn-in. Adam always converges, regardless of the width/height loss terms. I'm confused why, and I'm worried this is pointing to issues elsewhere in training (but not inference).

  3. I think you are right! The constants used to be in the .cfg files, but now they seem to be hard-coded into parser.c (as of V3). It seems I am using the old V1 constants. I've opened issue Loss Constants: _coord, _obj and _noobj #17 on this; please post there for this specific item.
    https://github.com/pjreddie/darknet/blob/680d3bde1924c8ee2d1c1dea54d3e56a05ca9a26/src/parser.c#L376-L381

  4. Size_average is an alternative of course, but I was under the assumption that darknet does not do this; it sums the losses and divides by the batch_size. Except that noobj has many more samples than all the other loss terms (i.e. 1000x as many samples), so summing this term is inappropriate as it would dominate the gradient. Training would still converge, but the trained model would never detect any objects, as it would have favored noobj during training.

As for your comment about 0.5 iou, as I understand it any anchor that has >0.5 iou with the nearest target, but is not the best anchor, should not be penalized by noobj (it is simply ignored, i.e. neither part of obj nor of noobj). I had implemented this in the past but saw no significant differences with or without it. I can place it back in for correctness though.
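
A minimal sketch of that ignore rule, under assumed shapes (iou_anchors_targets is [num_anchors, num_targets] and best_anchor_idx holds the best-anchor index per target); this is an illustration, not the repo's actual mask-building code:

import torch

def build_conf_masks(iou_anchors_targets, best_anchor_idx, ignore_thresh=0.5):
    num_anchors = iou_anchors_targets.shape[0]

    obj_mask = torch.zeros(num_anchors, dtype=torch.bool)
    obj_mask[best_anchor_idx] = True                  # best anchor per target learns objectness

    best_iou_per_anchor, _ = iou_anchors_targets.max(dim=1)
    ignore_mask = (best_iou_per_anchor > ignore_thresh) & ~obj_mask   # high-iou but not best: ignored

    noobj_mask = ~obj_mask & ~ignore_mask             # everything else learns no-objectness
    return obj_mask, noobj_mask, ignore_mask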

I will try and experiment with simply size_averaging all the loss terms as well to see if this helps SGD convergence. But to be clear, the current implementation should be mathematically equivalent to size_averaging all loss terms and multiplying them all by nM, so I'd be surprised to see any differences, as the change would simply divide the current loss by nM.

Thanks for your experiments and your insights, especially catching the constants mistake!

@glenn-jocher
Member

glenn-jocher commented Sep 23, 2018

@ydixon I investigated more; it appears there is an issue with the size averaging, specific to the BCE losses. nn.MSELoss behaves as expected, size-averaging by nM = 47, the number of samples in the first training batch:

MSE_sa_true = nn.MSELoss(size_average=True)
MSE_sa_false = nn.MSELoss(size_average=False)

nM
(nM / nB) * MSE_sa_true(x[mask], tx[mask])
(1 / nB) * MSE_sa_false(x[mask], tx[mask])

Out[19]: tensor(47., device='cuda:0')
Out[20]: tensor(0.34477, device='cuda:0', grad_fn=<MulBackward1>)
Out[21]: tensor(0.34477, device='cuda:0', grad_fn=<MulBackward0>)

So does nn.CrossEntropyLoss():

CE_sa_true = nn.CrossEntropyLoss(size_average=True)
CE_sa_false = nn.CrossEntropyLoss(size_average=False)

(nM / nB) * CE_sa_true(pred_cls[mask], torch.argmax(tcls, 1))
(1 / nB) * CE_sa_false(pred_cls[mask], torch.argmax(tcls, 1))

Out[25]: tensor(17.19778, device='cuda:0', grad_fn=<MulBackward1>)
Out[26]: tensor(17.19778, device='cuda:0', grad_fn=<MulBackward0>)

When I try nn.BCEWithLogitsLoss the second output is 80 times larger than the first, 80 being the number of COCO object classes.

BCE_sa_true = nn.BCEWithLogitsLoss(size_average=True)
BCE_sa_false = nn.BCEWithLogitsLoss(size_average=False)

(nM / nB) * BCE_sa_true(pred_cls[mask], tcls.float())
(1 / nB) * BCE_sa_false(pred_cls[mask], tcls.float())

Out[27]: tensor(2.71279, device='cuda:0', grad_fn=<MulBackward1>)
Out[28]: tensor(217.02332, device='cuda:0', grad_fn=<MulBackward0>)

So nn.BCEWithLogitsLoss is size averaging across all of its elements rather than just its rows:

pred_cls[mask].shape
Out[5]: torch.Size([47, 80])

This must be creating a huge training imbalance in my combined loss function. This is a very serious bug. There are only two possible corrections: either all BCE loss terms need to be divided by 80, or none do. I will test both ways for 1 epoch to determine the preferable route. Again thank you very much for bringing this to my attention.
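
A standalone reproduction of that mismatch (random stand-ins for pred_cls and tcls, with the shapes from the debug output; written with the modern reduction= argument, which replaced size_average):

import torch
import torch.nn as nn

nM, nC = 47, 80
pred_cls = torch.randn(nM, nC)
tcls = torch.zeros(nM, nC)
tcls[torch.arange(nM), torch.randint(0, nC, (nM,))] = 1.0   # one-hot class targets

bce_mean = nn.BCEWithLogitsLoss(reduction='mean')(pred_cls, tcls)
bce_sum = nn.BCEWithLogitsLoss(reduction='sum')(pred_cls, tcls)
print(bce_sum / bce_mean)   # ~3760 = nM * nC: the mean divides by rows * classes

ce_mean = nn.CrossEntropyLoss(reduction='mean')(pred_cls, tcls.argmax(1))
ce_sum = nn.CrossEntropyLoss(reduction='sum')(pred_cls, tcls.argmax(1))
print(ce_sum / ce_mean)     # ~47 = nM: the mean divides only by rows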

@glenn-jocher glenn-jocher added the bug Something isn't working label Sep 23, 2018
@glenn-jocher glenn-jocher self-assigned this Sep 23, 2018
@glenn-jocher
Member

glenn-jocher commented Sep 23, 2018

@ydixon corrections made in commit cf9b4cf to both the loss constants (all = 1.0 now) and the loss terms (all size_averaged before multiplying by nM / nB).

Initial results show improved training performance. Training using original yolo width and height terms now converges, so I've updated those terms back to their original yolo selves in commit 5d402ad.

TODO: Additional work needs to be done to ignore non-best anchors with > 0.50 IoU.

@glenn-jocher
Member

@okanlv it seems we are almost on par with darknet training now. See #310.

Closing this issue as it seems resolved.
