[Feature] Add Decoupled KD Loss #222
Conversation
…kd_loss Conflicts: mmrazor/models/losses/__init__.py
```
![avatar](../../../docs/imgs/model_zoo/dkd/dkd.png)

## Results and models
```
update readme
fixed
```python
    # Batch augmentation may convert labels to one-hot format scores.
    gt_labels = torch.stack([i.gt_label.score for i in data_samples])
else:
    gt_labels = torch.hstack([i.gt_label.label for i in data_samples])
```
It seems that `data_samples` is only used to get `gt_labels`. Could we hook the `gt_labels` directly? (By registering a hook on `head.loss_module`.)
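For context, a minimal sketch of the kind of hook this suggests, assuming a plain PyTorch module sits at `head.loss_module` and receives the labels as its second positional argument (both are assumptions, not the mmrazor API):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


# Hypothetical stand-in for `head.loss_module`; the real module and its call
# signature may differ.
class DummyLossModule(nn.Module):
    def forward(self, cls_score, gt_labels):
        return F.cross_entropy(cls_score, gt_labels)


captured = {}


def grab_gt_labels(module, args):
    # Forward pre-hooks receive the positional args the module is called with;
    # here we assume the labels are the second argument.
    captured['gt_labels'] = args[1]


loss_module = DummyLossModule()
handle = loss_module.register_forward_pre_hook(grab_gt_labels)

loss = loss_module(torch.randn(4, 10), torch.randint(0, 10, (4,)))
print(captured['gt_labels'])  # the labels seen by the loss module
handle.remove()
```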
Since the distill head in gml is deprecated and replaced by `recorder`, `connector` and `loss_forward_mapping`, the distill loss can only receive mapped params and cannot be hooked under the existing framework.
```python
    mask1: torch.Tensor,
    mask2: torch.Tensor,
) -> torch.Tensor:
    t1 = (target * mask1).sum(dim=1, keepdims=True)
```
It would be better if the parameter names were more descriptive.
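For illustration only (not the code under review), one way to name the masks and intermediates so they describe the DKD decomposition, assuming `mask1`/`mask2` select the target and non-target classes respectively:

```python
import torch


def make_masks(logits: torch.Tensor, gt_labels: torch.Tensor):
    """Build masks from integer labels of shape (N,).

    `gt_mask` is 1 only at the ground-truth class; `non_gt_mask` is its
    complement.
    """
    gt_mask = torch.zeros_like(logits).scatter_(1, gt_labels.unsqueeze(1), 1.)
    non_gt_mask = 1. - gt_mask
    return gt_mask, non_gt_mask


def binary_probs(probs: torch.Tensor, gt_mask: torch.Tensor,
                 non_gt_mask: torch.Tensor) -> torch.Tensor:
    """Collapse class probabilities into (target, non-target) pairs,
    i.e. the binary distribution used by the TCKD term."""
    p_target = (probs * gt_mask).sum(dim=1, keepdim=True)
    p_non_target = (probs * non_gt_mask).sum(dim=1, keepdim=True)
    return torch.cat([p_target, p_non_target], dim=1)
```

Names like `gt_mask`/`non_gt_mask` read more clearly than `mask1`/`mask2` once the target/non-target split shows up later in the loss.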
```shell
sh tools/slurm_train.sh $PARTITION $JOB_NAME \
    configs/distill/mmcls/dkd/dkd_logits_r34_r18_8xb32_in1k.py \
    $DISTILLATION_WORK_DIR
```
Add a special acknowledgement:
### Acknowledgement
Shout out to @Davidgzx for his special contribution.
Codecov Report

```diff
@@            Coverage Diff            @@
##           dev-1.x    #222    +/-   ##
==========================================
- Coverage     0.63%    0.62%   -0.01%
==========================================
  Files          124      125       +1
  Lines         4601     4642      +41
  Branches       721      719       -2
==========================================
  Hits            29       29
- Misses        4567     4608      +41
  Partials         5        5
```

Flags with carried forward coverage won't be shown.
* add DKDLoss, config
* linting
* linting
* reset default reduction
* dkd ut
* Update decoupled_kd.py
* Update decoupled_kd.py
* Update decoupled_kd.py
* fix commit
* fix readme
* fix comments
* linting comment
* rename loss params
* fix docstring
* Update decoupled_kd.py
* fix gt from config
* merge fix
* fix ut & wsld
* Update README.md
* Update README.md
* add Acknowledgement
* Update README.md
* Update README.md
* Update README.md
* Update README.md
* fix readme style
* fix md

Co-authored-by: zengyi.vendor <zengyi.vendor@sensetime.com>
* add doc
* Update
* Resolve comments
Dear spynccat,
Motivation
Add Decoupled KD Loss (DKDLoss).
Add DKDLoss config, readme and UT.
Add UT in `test_distillation_losses` for DKDLoss.
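A minimal smoke test along those lines might look like the sketch below; the import path, the constructor arguments (`tau`, `alpha`, `beta`) and the `(preds_S, preds_T, gt_labels)` call order are assumptions about the DKDLoss interface, not taken from this diff:

```python
import torch
from mmrazor.models import DKDLoss  # assumed import path


def test_dkd_loss():
    # Assumed constructor and forward signature; adjust to the actual API.
    dkd_loss = DKDLoss(tau=1.0, alpha=1.0, beta=1.0)
    preds_S = torch.randn(4, 10)            # student logits
    preds_T = torch.randn(4, 10)            # teacher logits
    gt_labels = torch.randint(0, 10, (4,))  # integer class labels
    loss = dkd_loss(preds_S, preds_T, gt_labels)
    # The loss should be a non-negative scalar tensor.
    assert loss.dim() == 0 and loss.item() >= 0
```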