[Feature] Add Decoupled KD Loss #222

Merged · 31 commits merged into open-mmlab:dev-1.x on Aug 15, 2022
Conversation

spynccat (Contributor) commented Aug 5, 2022

Motivation

Add Decoupled KD Loss (DKDLoss); a sketch of the decoupled formulation is given below.
Add the DKDLoss config, README and unit tests (UT).
Add a UT for DKDLoss in test_distillation_losses.
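For context, DKD ("Decoupled Knowledge Distillation", Zhao et al., CVPR 2022) splits the classical KD loss into a target-class term (TCKD) and a non-target-class term (NCKD), combined as `DKD = alpha * TCKD + beta * NCKD`. Below is a minimal PyTorch sketch of that formulation; it mirrors the paper's reference code rather than the exact implementation merged in this PR, and the default `alpha`, `beta`, `tau` values are just the paper's CIFAR settings:

```python
import torch
import torch.nn.functional as F


def dkd_loss(logits_S, logits_T, gt_labels, alpha=1.0, beta=8.0, tau=4.0):
    """Sketch of Decoupled KD: loss = alpha * TCKD + beta * NCKD."""
    gt_mask = F.one_hot(gt_labels, num_classes=logits_S.size(1)).bool()
    # TCKD: binary KL between the (target, non-target) probability masses.
    p_S = F.softmax(logits_S / tau, dim=1)
    p_T = F.softmax(logits_T / tau, dim=1)
    b_S = torch.cat([(p_S * gt_mask).sum(1, keepdim=True),
                     (p_S * ~gt_mask).sum(1, keepdim=True)], dim=1)
    b_T = torch.cat([(p_T * gt_mask).sum(1, keepdim=True),
                     (p_T * ~gt_mask).sum(1, keepdim=True)], dim=1)
    tckd = F.kl_div(b_S.log(), b_T, reduction='batchmean')
    # NCKD: KL over the non-target classes only; pushing the target logit
    # far down (-1000) effectively removes it from the softmax.
    log_q_S = F.log_softmax(logits_S / tau - 1000.0 * gt_mask, dim=1)
    q_T = F.softmax(logits_T / tau - 1000.0 * gt_mask, dim=1)
    nckd = F.kl_div(log_q_S, q_T, reduction='batchmean')
    return (alpha * tckd + beta * nckd) * tau**2
```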

wilxy requested review from pppppM and sunnyxiaohu and removed the request for sunnyxiaohu and pppppM on Aug 8, 2022

![avatar](../../../docs/imgs/model_zoo/dkd/dkd.png)

## Results and models
Reviewer (Contributor): update readme

Author (spynccat): fixed

mmrazor/models/losses/decoupled_kd.py (resolved review threads; excerpt under discussion):

```python
    # Batch augmentation may convert labels to one-hot format scores.
    gt_labels = torch.stack([i.gt_label.score for i in data_samples])
else:
    gt_labels = torch.hstack([i.gt_label.label for i in data_samples])
```
Collaborator: It seems that data_samples is only used to get gt_labels. Could we hook the gt_labels directly (by registering a hook on head.loss_module)?

Author (spynccat): Since the distill head in gml is deprecated and replaced by recorder, connector and loss_forward_mapping, the distill loss can only receive mapped params and cannot be hooked under the existing framework.
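For readers unfamiliar with the mechanism referenced above, the sketch below shows roughly how a ConfigurableDistiller config wires recorded tensors into a distill loss through loss_forward_mappings. The recorder sources and mapping fields here are illustrative assumptions and may not match the actual dkd_logits_r34_r18_8xb32_in1k.py config:

```python
# Illustrative sketch only; field values are assumptions, not copied from
# the merged config.
distiller = dict(
    type='ConfigurableDistiller',
    student_recorders=dict(
        # Record the student head's logits.
        fc=dict(type='ModuleOutputs', source='head.fc'),
        # Hypothetical: record the labels where they enter the head's loss.
        gt_labels=dict(type='ModuleInputs', source='head.loss_module')),
    teacher_recorders=dict(
        fc=dict(type='ModuleOutputs', source='head.fc')),
    distill_losses=dict(
        loss_dkd=dict(type='DKDLoss', tau=1.0, loss_weight=1.0)),
    # Every forward argument of DKDLoss is fed from a recorder, which is
    # why the loss "can only receive mapped params" and cannot be hooked.
    loss_forward_mappings=dict(
        loss_dkd=dict(
            preds_S=dict(from_student=True, recorder='fc'),
            preds_T=dict(from_student=False, recorder='fc'),
            gt_labels=dict(from_student=True, recorder='gt_labels'))))
```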

mmrazor/models/losses/decoupled_kd.py (resolved review thread; excerpt under discussion):

```python
    mask1: torch.Tensor,
    mask2: torch.Tensor,
) -> torch.Tensor:
    t1 = (target * mask1).sum(dim=1, keepdims=True)
```
Collaborator: It would be better if the parameters' names were more meaningful.
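As a concrete (hypothetical) example of such a rename, the helper around the excerpt above could name the masks after what they select; this is a sketch, not the committed code:

```python
import torch


def _cat_mask(probs: torch.Tensor, gt_mask: torch.Tensor,
              non_gt_mask: torch.Tensor) -> torch.Tensor:
    # Collapse per-class probabilities into two bins: the target-class
    # mass and the non-target-class mass, concatenated along dim 1.
    target_mass = (probs * gt_mask).sum(dim=1, keepdims=True)
    non_target_mass = (probs * non_gt_mask).sum(dim=1, keepdims=True)
    return torch.cat([target_mass, non_target_mass], dim=1)
```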

spynccat requested a review from HIT-cwh on Aug 10, 2022
```bash
sh tools/slurm_train.sh $PARTITION $JOB_NAME \
  configs/distill/mmcls/dkd/dkd_logits_r34_r18_8xb32_in1k.py \
  $DISTILLATION_WORK_DIR
```
fpshuang (Contributor) commented Aug 10, 2022:

Add a special acknowledgement:

### Acknowledgement

Shout out to @Davidgzx for his special contribution.

codecov bot commented Aug 15, 2022

Codecov Report

Merging #222 (02f4425) into dev-1.x (6e8ebfd) will decrease coverage by 0.00%.
The diff coverage is 0.00%.

❗ Current head 02f4425 differs from the pull request's most recent head 33b3883. Consider uploading reports for the commit 33b3883 to get more accurate results.

```
@@            Coverage Diff             @@
##           dev-1.x    #222      +/-   ##
==========================================
- Coverage     0.63%   0.62%   -0.01%
==========================================
  Files          124     125       +1
  Lines         4601    4642      +41
  Branches       721     719       -2
==========================================
  Hits            29      29
- Misses        4567    4608      +41
  Partials         5       5
```
| Flag | Coverage Δ |
|------|------------|
| unittests | 0.62% <0.00%> (-0.01%) ⬇️ |

Flags with carried forward coverage won't be shown.

| Impacted Files | Coverage Δ |
|----------------|------------|
| mmrazor/models/losses/__init__.py | 0.00% <0.00%> (ø) |
| mmrazor/models/losses/decoupled_kd.py | 0.00% <0.00%> (ø) |
| .../models/losses/weighted_soft_label_distillation.py | 0.00% <0.00%> (ø) |


pppppM merged commit a1937fd into open-mmlab:dev-1.x on Aug 15, 2022
LKJacky pushed a commit to LKJacky/mmrazor that referenced this pull request Aug 20, 2022
* add DKDLoss, config
* linting
* linting
* reset default reduction
* dkd ut
* Update decoupled_kd.py
* Update decoupled_kd.py
* Update decoupled_kd.py
* fix commit
* fix readme
* fix comments
* linting comment
* rename loss params
* fix docstring
* Update decoupled_kd.py
* fix gt from config
* merge fix
* fix ut & wsld
* Update README.md
* Update README.md
* add Acknowledgement
* Update README.md
* Update README.md
* Update README.md
* Update README.md
* fix readme style
* fix md

Co-authored-by: zengyi.vendor <zengyi.vendor@sensetime.com>
humu789 pushed a commit to humu789/mmrazor that referenced this pull request Feb 13, 2023
* add doc
* Update
* Resolve comments
openmmlab-bot (Collaborator) commented:
Dear spynccat,

First of all, we want to express our gratitude for your significant PR to the MMRazor project. Your contribution is highly appreciated, and we are grateful for your efforts in helping improve this open-source project in your personal time. We believe that many developers will benefit from your PR.

If you are Chinese or have WeChat, you are welcome to join our community on WeChat. You can add our assistant: openmmlabwx. Please add "mmsig + GitHub ID" as a remark when adding friends :)

We would also like to invite you to join our Special Interest Group (SIG) private channel on Discord, where you can share your experiences and ideas and build connections with like-minded peers. To join the SIG channel, simply message the moderator, OpenMMLab, on Discord, or briefly share your open-source contributions in the #introductions channel, and we will assist you. We look forward to seeing you there! Join us: https://discord.gg/raweFPmdzG

Thank you again for your contribution ❤

Best regards! @spynccat
