Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement break-down mode for cross entropy calculation #1257

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 24 additions & 6 deletions yolox/models/yolo_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,12 +487,30 @@ def get_assignments(

with torch.cuda.amp.autocast(enabled=False):
cls_preds_ = (
cls_preds_.float().unsqueeze(0).repeat(num_gt, 1, 1).sigmoid_()
* obj_preds_.float().unsqueeze(0).repeat(num_gt, 1, 1).sigmoid_()
)
pair_wise_cls_loss = F.binary_cross_entropy(
cls_preds_.sqrt_(), gt_cls_per_image, reduction="none"
).sum(-1)
cls_preds_.float().unsqueeze(0).sigmoid_()
* obj_preds_.float().unsqueeze(0).sigmoid_()
).sqrt_()

try:
# may fail to allocate GPU memory when num_gt is large
pair_wise_cls_loss = F.binary_cross_entropy(
cls_preds_.repeat(num_gt, 1, 1), gt_cls_per_image,
reduction="none"
).sum(-1)
except RuntimeError as e:
# TODO: the string might change, consider a better way
if mode == "cpu" or "CUDA out of memory. " not in str(e):
raise # RuntimeError might not caused by CUDA OOM
# to work with less GPU memory
print('--- break-down mode for cross-entropy calculations ---')
pair_wise_cls_loss = torch.empty(
(num_gt, torch.sum(fg_mask)),
dtype=cls_preds_.dtype, device=cls_preds_.device)
for i in range(num_gt):
pair_wise_cls_loss[i] = F.binary_cross_entropy(
cls_preds_, gt_cls_per_image[i].unsqueeze(0),
reduction="none"
).sum(-1)
del cls_preds_

cost = (
Expand Down