
Distillation support for torchvision script #1310

Merged: 9 commits into main from feature/torchvision-distillation-support on Jan 11, 2023

Conversation

@rahul-tuli (Member) commented Jan 10, 2023

The goal of this PR is to add distillation support to our pytorch/torchvision integration.
Test recipe:

distillation.yaml:

num_epochs: &num_epochs 10
lr: &lr 0.008

training_modifiers:
  - !EpochRangeModifier
    start_epoch: 0.0
    end_epoch: *num_epochs

  - !SetLearningRateModifier
    start_epoch: 0.0
    learning_rate: *lr

  - !DistillationModifier
    start_epoch: 0.0
    hardness: 0.5
    temperature: 2.0
    distill_output_keys: [0]
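
For reference, hardness and temperature control how the distillation term is blended with the ordinary task loss: temperature softens the teacher and student distributions before comparing them, and hardness weights the distillation term against the cross-entropy term. A minimal sketch of the standard knowledge-distillation formulation these parameters map onto (illustrative only, not necessarily SparseML's exact implementation):

    import torch.nn.functional as F

    def distillation_loss(student_logits, teacher_logits, labels, hardness=0.5, temperature=2.0):
        # Soft term: KL divergence between temperature-scaled teacher and student distributions
        soft = F.kl_div(
            F.log_softmax(student_logits / temperature, dim=-1),
            F.softmax(teacher_logits / temperature, dim=-1),
            reduction="batchmean",
        ) * (temperature ** 2)
        # Hard term: ordinary cross-entropy against the ground-truth labels
        hard = F.cross_entropy(student_logits, labels)
        # hardness=0.5 gives equal weight to the distillation and task terms
        return hardness * soft + (1.0 - hardness) * hard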

Test commands (run manually):

  • Self-distillation
sparseml.image_classification.train \
    --recipe distillation.yaml --pretrained True --pretrained-dataset imagenette \
    --arch-key resnet50 --dataset-path /home/rahul/datasets/imagenette/imagenette-160 \
    --batch-size 128 --opt SGD --output-dir ./training-runs/image_classification-pretrained \
    --distill-teacher self
  • This command should fail (distillation recipe given but no --distill-teacher specified)
sparseml.image_classification.train \
    --recipe distillation.yaml --pretrained True --pretrained-dataset imagenette \
    --arch-key resnet50 --dataset-path /home/rahul/datasets/imagenette/imagenette-160 \
    --batch-size 128 --opt SGD --output-dir ./training-runs/image_classification-pretrained 
  • Disable distillation
sparseml.image_classification.train \
    --recipe distillation.yaml --pretrained True --pretrained-dataset imagenette \
    --arch-key resnet50 --dataset-path /home/rahul/datasets/imagenette/imagenette-160 \
    --batch-size 128 --opt SGD --output-dir ./training-runs/image_classification-pretrained \
    --distill-teacher disable
  • Distill a MobileNet using a ResNet-50 teacher from SparseZoo
sparseml.image_classification.train \
    --recipe distillation.yaml --pretrained True --pretrained-dataset imagenette \
    --arch-key mobilenet --dataset-path /home/rahul/datasets/imagenette/imagenette-160 \
    --batch-size 128 --opt SGD --output-dir ./training-runs/image_classification-pretrained \
    --distill-teacher zoo:cv/classification/resnet_v1-50/pytorch/sparseml/imagenet/base-none \
    --pretrained-teacher-dataset imagenet --teacher-arch-key resnet50

@rahul-tuli marked this pull request as ready for review on January 10, 2023, 18:51
@rahul-tuli changed the title from "[WIP] Distillation support for torchvision script" to "Distillation support for torchvision script" on Jan 10, 2023
@corey-nm (Contributor) left a comment:

changes look good, but need to call manager.update_loss (or whatever the call is) to actually use distillation loss

@rahul-tuli (Member, Author) replied, quoting the above comment:

changes look good, but need to call manager.update_loss (or whatever the call is) to actually use distillation loss

Great catch, updated!

KSGulin previously approved these changes Jan 11, 2023

@KSGulin (Contributor) left a comment:

Left one comment, otherwise LGTM!

src/sparseml/pytorch/torchvision/train.py (review comment, since resolved)
bfineran previously approved these changes Jan 11, 2023

@bfineran (Member) left a comment:

Looks great @rahul-tuli

@corey-nm (Contributor) left a comment:

Looks like loss_update returns the new loss to use! So close 😀
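
For context, the pattern being discussed is roughly the following. This is a sketch, assuming SparseML's ScheduledModifierManager.loss_update signature and the student_outputs/student_inputs keyword arguments; variable names are illustrative, not the exact diff:

    def train_one_batch(manager, model, criterion, optimizer, images, targets, epoch, steps_per_epoch):
        outputs = model(images)
        loss = criterion(outputs, targets)
        # loss_update may wrap the task loss with the distillation term;
        # the returned value is the loss that must be used for the backward pass.
        loss = manager.loss_update(
            loss, model, optimizer, epoch, steps_per_epoch,
            student_outputs=outputs, student_inputs=images,
        )
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        return loss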

@rahul-tuli dismissed stale reviews from bfineran and KSGulin via 735956e on January 11, 2023, 14:48
corey-nm previously approved these changes Jan 11, 2023
@corey-nm (Contributor) left a comment:

🚀 LETS GOOOOOO

bfineran previously approved these changes Jan 11, 2023
@rahul-tuli dismissed stale reviews from bfineran and corey-nm via 6be8dd8 on January 11, 2023, 15:02
KSGulin previously approved these changes Jan 11, 2023
Co-authored-by: corey-nm <109536191+corey-nm@users.noreply.github.com>
@rahul-tuli merged commit adb30a0 into main on Jan 11, 2023
@rahul-tuli deleted the feature/torchvision-distillation-support branch on January 11, 2023, 15:20
bfineran pushed a commit that referenced this pull request Feb 2, 2023
* Add support for `self` distillation and `disable`

* Pull out model creation into a method

* Add support to distill with another model

* Add modifier loss update before backward pass

* bugfix, set loss

* Update src/sparseml/pytorch/torchvision/train.py

Co-authored-by: corey-nm <109536191+corey-nm@users.noreply.github.com>

Co-authored-by: corey-nm <109536191+corey-nm@users.noreply.github.com>
bfineran pushed a commit that referenced this pull request Feb 3, 2023