Model Loss #12
@okanlv Question 3 has to do with balancing the objectness loss term, which is split into two parts: one for the object cells (`mask`) and one for the background cells (`~mask`). The current version tries to balance them equally by averaging the loss term for the background cells and rescaling it by `nM`. If I start training and debug models.py, I see:

```
pred_conf.shape
Out[2]: torch.Size([12, 3, 13, 13])

nM
Out[3]: tensor(47.)

(~mask).sum()
Out[4]: tensor(6037)
```

If I apply the same summed BCE loss to both terms, the noobj term dwarfs the obj term by two orders of magnitude:

```
loss_obj = BCEWithLogitsLoss1(pred_conf[mask], mask[mask].float())
Out[14]: tensor(32.41945, grad_fn=<SumBackward0>)

loss_noobj = 0.5 * BCEWithLogitsLoss1(pred_conf[~mask], mask[~mask].float())
Out[16]: tensor(2089.89233, grad_fn=<MulBackward0>)
```

If I use the current code for line 167, then you can see that the loss terms are now roughly equal in magnitude, so they both influence the gradient to similar degrees:

```
loss_noobj = 0.5 * nM * BCEWithLogitsLoss2(pred_conf[~mask], mask[~mask].float())
Out[17]: tensor(16.27049, grad_fn=<MulBackward1>)
```
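For reference, the same comparison can be reproduced with synthetic tensors (a standalone sketch, not the repo's code; only the shapes mirror the debug session above, and the modern `reduction=` argument replaces the deprecated `size_average`):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
pred_conf = torch.randn(12, 3, 13, 13)                   # [nB, nA, nG, nG] logits
mask = torch.zeros(12, 3, 13, 13, dtype=torch.bool)
mask.view(-1)[torch.randperm(mask.numel())[:47]] = True  # nM = 47 object cells
nM = mask.sum().float()

bce_sum = nn.BCEWithLogitsLoss(reduction='sum')    # "BCEWithLogitsLoss1"
bce_mean = nn.BCEWithLogitsLoss(reduction='mean')  # "BCEWithLogitsLoss2"

# Summed both ways: the background term dwarfs the object term (~6k cells vs 47).
loss_obj = bce_sum(pred_conf[mask], mask[mask].float())
loss_noobj_sum = 0.5 * bce_sum(pred_conf[~mask], mask[~mask].float())

# Mean-reduced and rescaled by nM: both terms land in the same ballpark.
loss_noobj_bal = 0.5 * nM * bce_mean(pred_conf[~mask], mask[~mask].float())

print(loss_obj.item(), loss_noobj_sum.item(), loss_noobj_bal.item())
```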
@okanlv question 4 is related to issue #3. Question 5 is valid though, and I don't have a good answer. Are there some examples you could point to that show the loss term divided by the batch size?
@glenn-jocher thanks, Q3 & Q4 make sense now.
@glenn-jocher I think the cost is divided by the batch size at this line.
@okanlv I was just thinking about this. Currently the loss increases with the number of detected objects rather than with the batch size. But I see your link.
Oh, by the way, see models.py lines 116 to 119 in 1cfde4a.
@glenn-jocher I am not sure about this. You might be right. Training speed should not depend on the batch size, in my opinion; you can adjust the learning rate to speed up the training. Why do you think dividing by the batch size would slow training down?
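For what it's worth, the batch-size point is easy to demonstrate: with sum reduction the loss (and hence the gradient magnitude) grows linearly with `nB` for identical data, while mean reduction keeps the per-sample objective fixed (a toy sketch, unrelated to the repo's tensors):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
sample_pred = torch.randn(1, 10)  # one sample's logits
sample_targ = torch.rand(1, 10)   # and its targets

for nB in (1, 4, 16):
    pred = sample_pred.repeat(nB, 1)  # the same sample replicated nB times
    targ = sample_targ.repeat(nB, 1)
    summed = nn.BCEWithLogitsLoss(reduction='sum')(pred, targ)
    meaned = nn.BCEWithLogitsLoss(reduction='mean')(pred, targ)
    print(f"nB={nB:2d}  sum={summed.item():8.3f}  mean={meaned.item():.5f}")
# sum scales with nB; mean does not, which is what "minimize the
# expected loss per sample" asks for.
```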
@glenn-jocher Thanks for this excellent repo. I've been scratching my head trying to figure out how to balance the loss weights properly. Thanks for the explanation. I still have a few questions:
@ydixon don't thank me yet. The model works near-perfectly for inference, but issues seem to remain with training. To answer your questions: the model attempts to replicate yolov3 with darknet, so almost all design choices come from there. Definitely read the paper.
And when I try your new method (balancing by `nM`; models.py line 276 in bd3f617), any other anchor boxes will learn noobjectness, judging by this line where you just flipped the mask (line 176 in bd3f617). I think according to the paper, any anchor box that has IoU > 0.5 with a ground-truth box should not learn any objectness (noobj); a sketch of this rule follows below. Also, if the goal is to balance the weights between the obj and noobj terms… I'm going to keep testing. Please let me know your thoughts.
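For concreteness, the ignore rule from the paper could look something like this (a sketch; `box_iou`, the xyxy box format, and all names here are illustrative, not the repo's code):

```python
import torch

def box_iou(a, b):
    """IoU matrix between (N, 4) and (M, 4) boxes in xyxy format."""
    area_a = (a[:, 2] - a[:, 0]) * (a[:, 3] - a[:, 1])
    area_b = (b[:, 2] - b[:, 0]) * (b[:, 3] - b[:, 1])
    lt = torch.max(a[:, None, :2], b[None, :, :2])  # intersection top-left
    rb = torch.min(a[:, None, 2:], b[None, :, 2:])  # intersection bottom-right
    wh = (rb - lt).clamp(min=0)
    inter = wh[..., 0] * wh[..., 1]
    return inter / (area_a[:, None] + area_b[None, :] - inter)

def noobj_mask(pred_boxes, gt_boxes, ignore_thres=0.5):
    """True where an anchor should still be penalized as background.
    Anchors whose best IoU with any ground-truth box exceeds the
    threshold are excluded from the noobj loss, per the YOLOv3 paper."""
    best_iou, _ = box_iou(pred_boxes, gt_boxes).max(dim=1)
    return best_iou <= ignore_thres
```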
As for your other comment, I will try and experiment with simply size-averaging all the loss terms as well to see if this helps SGD convergence. But to be clear, the current implementation should be mathematically equivalent to size-averaging all loss terms and multiplying them all by `nM`. Thanks for your experiments and your insights, especially catching the constants mistake!
@ydixon I investigated more, and it appears there is an issue with the size averaging, specific to BCE losses. `nn.MSELoss` behaves as expected: size-averaging and rescaling by `nM` is equivalent to summing:

```
MSE_sa_true = nn.MSELoss(size_average=True)
MSE_sa_false = nn.MSELoss(size_average=False)

nM
Out[19]: tensor(47., device='cuda:0')

(nM / nB) * MSE_sa_true(x[mask], tx[mask])
Out[20]: tensor(0.34477, device='cuda:0', grad_fn=<MulBackward1>)

(1 / nB) * MSE_sa_false(x[mask], tx[mask])
Out[21]: tensor(0.34477, device='cuda:0', grad_fn=<MulBackward0>)
```

So does `nn.CrossEntropyLoss`:

```
CE_sa_true = nn.CrossEntropyLoss(size_average=True)
CE_sa_false = nn.CrossEntropyLoss(size_average=False)

(nM / nB) * CE_sa_true(pred_cls[mask], torch.argmax(tcls, 1))
Out[25]: tensor(17.19778, device='cuda:0', grad_fn=<MulBackward1>)

(1 / nB) * CE_sa_false(pred_cls[mask], torch.argmax(tcls, 1))
Out[26]: tensor(17.19778, device='cuda:0', grad_fn=<MulBackward0>)
```

When I try `nn.BCEWithLogitsLoss`, however, the two forms differ by a factor of 80 — the number of classes — because `size_average=True` divides by every element (`nM` × 80), not just the `nM` rows:

```
BCE_sa_true = nn.BCEWithLogitsLoss(size_average=True)
BCE_sa_false = nn.BCEWithLogitsLoss(size_average=False)

(nM / nB) * BCE_sa_true(pred_cls[mask], tcls.float())
Out[27]: tensor(2.71279, device='cuda:0', grad_fn=<MulBackward1>)

(1 / nB) * BCE_sa_false(pred_cls[mask], tcls.float())
Out[28]: tensor(217.02332, device='cuda:0', grad_fn=<MulBackward0>)
```

This must be creating a huge training imbalance in my combined loss function. This is a very serious bug. There are only two possible corrections: either all BCE loss terms need to be divided by 80, or none do. I will test both ways for 1 epoch to determine the preferable route. Again, thank you very much for bringing this to my attention.
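The factor of 80 can be isolated in a few lines: for an `(nM, nC)` input, the mean-reduced BCE divides by `nM × nC`, so its sum-to-mean ratio carries an extra `nC` relative to CrossEntropyLoss (a standalone sketch with random tensors, using the modern `reduction=` argument):

```python
import torch
import torch.nn as nn

nM, nC = 47, 80                                          # matched anchors, COCO classes
pred_cls = torch.randn(nM, nC)
tcls = torch.zeros(nM, nC)
tcls[torch.arange(nM), torch.randint(nC, (nM,))] = 1.0   # one-hot class targets

bce_sum = nn.BCEWithLogitsLoss(reduction='sum')(pred_cls, tcls)
bce_mean = nn.BCEWithLogitsLoss(reduction='mean')(pred_cls, tcls)
print((bce_sum / bce_mean).item())  # nM * nC = 3760: the mean divides by classes too

ce_sum = nn.CrossEntropyLoss(reduction='sum')(pred_cls, tcls.argmax(1))
ce_mean = nn.CrossEntropyLoss(reduction='mean')(pred_cls, tcls.argmax(1))
print((ce_sum / ce_mean).item())    # nM = 47: CE averages over rows only
```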
@ydixon corrections made in commit cf9b4cf to both the loss constants (all = 1.0 now) and the loss terms (all size-averaged before multiplying by `nM`). Initial results show improved training performance. Training using the original yolo width and height terms now converges, so I've updated those terms back to their original yolo selves in commit 5d402ad. TODO: additional work needs to be done to ignore non-best anchors, as you noted.
Hey,
Following Section 2.2 of YOLO, I have a few questions about the loss calculation shown at the end of this issue.

1. We are using `λ_coord = 5` from line 156 to line 159. Should we also use `λ_noobj = 0.5` in line 167?
2. Why are we multiplying BCELoss with `1.5` in line 160? I have not found any reference to this in the papers.
3. `pred_conf` gives us a `[batch_size x anchor_number x grid_size x grid_size]` tensor. Assuming `batch_size = 1`, `anchor_number = 3` and `grid_size = 2`, there are 12 elements in this tensor. If `nM = 3`, `pred_conf[~mask]` contains 9 elements, and so does `mask[~mask].float()`. `BCEWithLogitsLoss1` gives the sum of the BCE loss over these 9 elements, whereas `BCEWithLogitsLoss2` takes the mean of `BCEWithLogitsLoss1` (i.e. divides it by 9 in our case). Now, my question is: why are we multiplying `BCEWithLogitsLoss2` by `nM` instead of using `BCEWithLogitsLoss1` (which should probably also be divided by the batch size) in line 167? There is no such division in Section 2.2 of YOLO. Btw, `pred_conf[~mask]` could normally contain ~15k elements, so we are practically ignoring the confidence loss in line 167 (see the sketch after this post).
4. Similar to 3, we should use `BCEWithLogitsLoss1` (probably also divided by the batch size) in line 163, because `BCEWithLogitsLoss1(pred_cls[mask], tcls.float()) / BCEWithLogitsLoss2(pred_cls[mask], tcls.float()) = batch_size x nM x number_of_classes`.
5. Why are we not dividing all the losses by the `batch_size`? As the `batch_size` increases, the loss increases too. However, we should minimize the expected loss per sample.

yolov3/models.py, lines 155 to 167 in 9514e74
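A quick way to check the arithmetic in question 3 (a sketch; only the shapes follow the question, the values are random, and `reduction=` stands in for the deprecated `size_average`):

```python
import torch
import torch.nn as nn

# batch_size = 1, anchor_number = 3, grid_size = 2 -> 12 elements total
pred_conf = torch.randn(1, 3, 2, 2)
mask = torch.zeros(1, 3, 2, 2, dtype=torch.bool)
mask.view(-1)[:3] = True                 # nM = 3, so ~mask selects 9 elements
nM = mask.sum().float()

bce_sum = nn.BCEWithLogitsLoss(reduction='sum')    # "BCEWithLogitsLoss1"
bce_mean = nn.BCEWithLogitsLoss(reduction='mean')  # "BCEWithLogitsLoss2"

noobj_sum = bce_sum(pred_conf[~mask], mask[~mask].float())
noobj_mean = bce_mean(pred_conf[~mask], mask[~mask].float())
print((noobj_sum / noobj_mean).item())  # 9.0: the mean divides the sum by 9 cells
print((nM * noobj_mean).item())         # the nM-rescaled form questioned for line 167
```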