[Feature] Add DAFL Distillation (#235)
* 1. Add DAFL, including the config, DAFLLoss, and README. 2. Add DataFreeDistillation. 3. Add generators, including base_generator and dafl_generator. 4. Add the get_module_device and set_requires_grad functions to utils.

* 1. Fix the files that reported errors in the mypy test under py37: gather_tensors, datafree_distillation, and base_generator. 2. Fix other linting errors.

* 1. Revise some docstrings.

* 1. Add UT for DataFreeDistillation. 2. Add all type hints.

* 1. Add UT for generators and gather_tensors.

* 1. Add an assertion on batch_size in base_generator.

* 1. Run isort.

Co-authored-by: zhangzhongyu.vendor <zhangzhongyu.vendor@sensetime.com>
wilxy and zhangzhongyu.vendor authored Aug 23, 2022
1 parent 72c1175 commit 57aec1f
Showing 23 changed files with 1,092 additions and 18 deletions.
42 changes: 42 additions & 0 deletions configs/distill/mmcls/dafl/README.md
@@ -0,0 +1,42 @@
# Data-Free Learning of Student Networks (DAFL)

> [Data-Free Learning of Student Networks](https://doi.org/10.1109/ICCV.2019.00361)
<!-- [ALGORITHM] -->

## Abstract

Learning portable neural networks is very essential for computer vision for the purpose that pre-trained heavy deep models can be well applied on edge devices such as mobile phones and micro sensors. Most existing deep neural network compression and speed-up methods are very effective for training compact deep models, when we can directly access the training dataset. However, training data for the given deep network are often unavailable due to some practice problems (e.g. privacy, legal issue, and transmission), and the architecture of the given network are also unknown except some interfaces. To this end, we propose a novel framework for training efficient deep neural networks by exploiting generative adversarial networks (GANs). To be specific, the pre-trained teacher networks are regarded as a fixed discriminator and the generator is utilized for deviating training samples which can obtain the maximum response on the discriminator. Then, an efficient network with smaller model size and computational complexity is trained using the generated data and the teacher network, simultaneously. Efficient student networks learned using the proposed Data-Free Learning (DAFL) method achieve 92.22% and 74.47% accuracies using ResNet-18 without any training data on the CIFAR-10 and CIFAR-100 datasets, respectively. Meanwhile, our student network obtains an 80.56% accuracy on the CelebA benchmark.

![pipeline](/docs/en/imgs/model_zoo/dafl/pipeline.png)
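
The abstract describes the teacher as a fixed discriminator that steers the generator. A minimal, dependency-free sketch of the three generator-side terms named in this commit's config (`OnehotLikeLoss`, `InformationEntropyLoss`, `ActivationLoss`) might look like the following. The function names, the plain-list inputs, and the use of natural logarithms are assumptions made for illustration; this is not the code added by the commit, and the default weights are simply copied from the config's `loss_weight` values.

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax over one row of logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def onehot_loss(batch_logits: Sequence[Sequence[float]]) -> float:
    """Cross-entropy between each teacher prediction and its own argmax label.

    Minimizing it pushes generated images toward confident teacher outputs.
    """
    losses = [-math.log(max(softmax(row))) for row in batch_logits]
    return sum(losses) / len(losses)


def activation_loss(batch_feats: Sequence[Sequence[float]]) -> float:
    """Negative mean absolute feature value; minimizing it favors strong activations."""
    values = [abs(v) for row in batch_feats for v in row]
    return -sum(values) / len(values)


def information_entropy_loss(batch_logits: Sequence[Sequence[float]]) -> float:
    """Negative entropy of the batch-averaged prediction.

    Minimizing it encourages the generator to cover all classes evenly.
    """
    probs = [softmax(row) for row in batch_logits]
    mean = [sum(col) / len(probs) for col in zip(*probs)]
    return sum(p * math.log(p) for p in mean if p > 0)


def generator_loss(batch_logits, batch_feats, w_oh=0.05, w_ie=5.0, w_ac=0.01):
    """Weighted sum of the three terms (weights assumed from the config)."""
    return (w_oh * onehot_loss(batch_logits)
            + w_ie * information_entropy_loss(batch_logits)
            + w_ac * activation_loss(batch_feats))
```

With these definitions, a batch whose predictions are spread evenly over the classes gets a lower information-entropy term than a batch that collapses onto one class, which is the balancing behavior the generator objective relies on.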

## Results and models

### Classification

| Location | Dataset | Teacher | Student | Acc | Acc(T) | Acc(S) | Config | Download |
| :----------------------------------: | :-----: | :-------------------------------------------------------------------------------------------------------------: | :-------------------------------------------------------------------------------------------------------------: | :---: | :----: | :----: | :-----------------------------------------------: | :------------------------------------------------------------------------------------------------------------------------------------------- |
| backbone (pretrain) & logits (train) | Cifar10 | [resnet34](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet34_8xb16_cifar10.py) | [resnet18](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet18_8xb16_cifar10.py) | 93.11 | 95.34 | 94.82 | [config](./dafl_logits_r34_r18_8xb256_cifar10.py) | [teacher](https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_b16x8_cifar10_20210528-a8aa36a6.pth) \|[model](<>) \| [log](<>) |
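
The `logits (train)` location in the table corresponds to the student being distilled from the teacher's logits; the config in this commit uses a `KLDivergence` loss with `tau=6` for that. The sketch below shows one common temperature-scaled formulation. The function names and the `tau ** 2` scaling are assumptions based on the usual knowledge-distillation recipe, not a statement about the internals of the loss class used here.

```python
import math
from typing import List, Sequence


def log_softmax(logits: Sequence[float], tau: float) -> List[float]:
    """Log-softmax of logits softened by temperature tau."""
    scaled = [x / tau for x in logits]
    m = max(scaled)
    log_z = m + math.log(sum(math.exp(s - m) for s in scaled))
    return [s - log_z for s in scaled]


def kd_kl_loss(student_logits: Sequence[Sequence[float]],
               teacher_logits: Sequence[Sequence[float]],
               tau: float = 6.0) -> float:
    """Batch-mean KL(teacher || student) on temperature-softened distributions."""
    total = 0.0
    for s_row, t_row in zip(student_logits, teacher_logits):
        log_p_s = log_softmax(s_row, tau)
        log_p_t = log_softmax(t_row, tau)
        # KL(p_t || p_s) = sum_i p_t[i] * (log p_t[i] - log p_s[i])
        total += sum(math.exp(lt) * (lt - ls)
                     for lt, ls in zip(log_p_t, log_p_s))
    # Assumed tau^2 factor keeps gradient magnitudes comparable across temperatures.
    return (tau ** 2) * total / len(student_logits)
```

The loss is zero when the student reproduces the teacher's logits and grows as the two softened distributions diverge.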

## Citation

```latex
@inproceedings{DBLP:conf/iccv/ChenW0YLSXX019,
  author    = {Hanting Chen and Yunhe Wang and Chang Xu and Zhaohui Yang and
               Chuanjian Liu and Boxin Shi and Chunjing Xu and Chao Xu and Qi Tian},
  title     = {Data-Free Learning of Student Networks},
  booktitle = {2019 {IEEE/CVF} International Conference on Computer Vision, {ICCV}
               2019, Seoul, Korea (South), October 27 - November 2, 2019},
  pages     = {3513--3521},
  publisher = {{IEEE}},
  year      = {2019},
  url       = {https://doi.org/10.1109/ICCV.2019.00361},
  doi       = {10.1109/ICCV.2019.00361},
  timestamp = {Mon, 17 May 2021 08:18:18 +0200},
  biburl    = {https://dblp.org/rec/conf/iccv/ChenW0YLSXX019.bib},
  bibsource = {dblp computer science bibliography, https://dblp.org}
}
```

## Acknowledgement

Shout out to Davidgzx.
104 changes: 104 additions & 0 deletions configs/distill/mmcls/dafl/dafl_logits_r34_r18_8xb256_cifar10.py
@@ -0,0 +1,104 @@
_base_ = [
    'mmcls::_base_/datasets/cifar10_bs16.py',
    'mmcls::_base_/schedules/cifar10_bs128.py',
    'mmcls::_base_/default_runtime.py'
]

model = dict(
    _scope_='mmrazor',
    type='DAFLDataFreeDistillation',
    data_preprocessor=dict(
        type='ImgDataPreprocessor',
        # RGB format normalization parameters
        mean=[125.307, 122.961, 113.8575],
        std=[51.5865, 50.847, 51.255],
        # whether to convert images from BGR to RGB
        bgr_to_rgb=False),
    architecture=dict(
        cfg_path='mmcls::resnet/resnet18_8xb16_cifar10.py', pretrained=False),
    teachers=dict(
        res34=dict(
            build_cfg=dict(
                cfg_path='mmcls::resnet/resnet34_8xb16_cifar10.py',
                pretrained=True),
            ckpt_path='resnet34_b16x8_cifar10_20210528-a8aa36a6.pth')),
    generator=dict(
        type='DAFLGenerator',
        img_size=32,
        latent_dim=1000,
        hidden_channels=128),
    distiller=dict(
        type='ConfigurableDistiller',
        student_recorders=dict(
            fc=dict(type='ModuleOutputs', source='head.fc')),
        teacher_recorders=dict(
            res34_fc=dict(type='ModuleOutputs', source='res34.head.fc')),
        distill_losses=dict(
            loss_kl=dict(type='KLDivergence', tau=6, loss_weight=1)),
        loss_forward_mappings=dict(
            loss_kl=dict(
                preds_S=dict(from_student=True, recorder='fc'),
                preds_T=dict(from_student=False, recorder='res34_fc')))),
    generator_distiller=dict(
        type='ConfigurableDistiller',
        teacher_recorders=dict(
            res34_neck_gap=dict(type='ModuleOutputs', source='res34.neck.gap'),
            res34_fc=dict(type='ModuleOutputs', source='res34.head.fc')),
        distill_losses=dict(
            loss_res34_oh=dict(type='OnehotLikeLoss', loss_weight=0.05),
            loss_res34_ie=dict(type='InformationEntropyLoss', loss_weight=5),
            loss_res34_ac=dict(type='ActivationLoss', loss_weight=0.01)),
        loss_forward_mappings=dict(
            loss_res34_oh=dict(
                preds_T=dict(from_student=False, recorder='res34_fc')),
            loss_res34_ie=dict(
                preds_T=dict(from_student=False, recorder='res34_fc')),
            loss_res34_ac=dict(
                feat_T=dict(from_student=False, recorder='res34_neck_gap')))))

# model wrapper
model_wrapper_cfg = dict(
    type='mmengine.MMSeparateDistributedDataParallel',
    broadcast_buffers=False,
    find_unused_parameters=False)

find_unused_parameters = True

# optimizer wrapper
optim_wrapper = dict(
    _delete_=True,
    constructor='mmrazor.SeparateOptimWrapperConstructor',
    architecture=dict(optimizer=dict(type='AdamW', lr=1e-1)),
    generator=dict(optimizer=dict(type='AdamW', lr=1e-3)))

auto_scale_lr = dict(base_batch_size=256)

param_scheduler = dict(
    _delete_=True,
    architecture=[
        dict(type='LinearLR', end=500, by_epoch=False, start_factor=0.0001),
        dict(
            type='MultiStepLR',
            begin=500,
            milestones=[100 * 120, 200 * 120],
            by_epoch=False)
    ],
    generator=dict(
        type='LinearLR', end=500, by_epoch=False, start_factor=0.0001))

train_cfg = dict(
    _delete_=True, by_epoch=False, max_iters=250 * 120, val_interval=150)

train_dataloader = dict(
    batch_size=256, sampler=dict(type='InfiniteSampler', shuffle=True))
val_dataloader = dict(batch_size=256)
val_evaluator = dict(type='Accuracy', topk=(1, 5))

default_hooks = dict(
    logger=dict(type='LoggerHook', interval=75, log_metric_by_epoch=False),
    checkpoint=dict(
        type='CheckpointHook', by_epoch=False, interval=150, max_keep_ckpts=2))

log_processor = dict(by_epoch=False)
# Must set diff_rank_seed to True!
randomness = dict(seed=None, diff_rank_seed=True)
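
The schedule in this config is expressed in iterations: `max_iters=250 * 120` is 30,000 iterations, and the `MultiStepLR` milestones `100 * 120` and `200 * 120` are 12,000 and 24,000. A rough sketch of the resulting learning-rate multiplier for the `architecture` optimizer is given below. It assumes a decay factor of `gamma=0.1` (the config does not set one) and counts milestones in global iterations; whether the real scheduler counts them from iteration 0 or from `begin` is not stated in this commit, so treat the function as an illustration only. The name `lr_factor` is hypothetical.

```python
from typing import Sequence

MAX_ITERS = 250 * 120  # 30,000 training iterations, as in train_cfg


def lr_factor(iteration: int,
              warmup_iters: int = 500,
              start_factor: float = 1e-4,
              milestones: Sequence[int] = (100 * 120, 200 * 120),
              gamma: float = 0.1) -> float:
    """Multiplier applied to the base learning rate at a given iteration.

    Assumptions: linear warmup from start_factor to 1.0, then a step decay
    by gamma at each milestone, with milestones counted in global iterations.
    """
    if iteration < warmup_iters:
        # Linear interpolation between start_factor and 1.0 during warmup.
        return start_factor + (1.0 - start_factor) * iteration / warmup_iters
    passed = sum(1 for m in milestones if iteration >= m)
    return gamma ** passed
```

Under these assumptions the architecture learning rate ramps up over the first 500 iterations, stays at its base value until iteration 12,000, and is reduced twice before training ends at iteration 30,000.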
Binary file added docs/en/imgs/model_zoo/dafl/pipeline.png
6 changes: 4 additions & 2 deletions mmrazor/models/algorithms/__init__.py
@@ -1,11 +1,13 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .base import BaseAlgorithm
-from .distill import FpnTeacherDistill, SelfDistill, SingleTeacherDistill
+from .distill import (DAFLDataFreeDistillation, DataFreeDistillation,
+                      FpnTeacherDistill, SelfDistill, SingleTeacherDistill)
 from .nas import SPOS, AutoSlim, AutoSlimDDP, Darts, DartsDDP
 from .pruning import SlimmableNetwork, SlimmableNetworkDDP

 __all__ = [
     'SingleTeacherDistill', 'BaseAlgorithm', 'FpnTeacherDistill', 'SPOS',
     'SlimmableNetwork', 'SlimmableNetworkDDP', 'AutoSlim', 'AutoSlimDDP',
-    'Darts', 'DartsDDP', 'SelfDistill'
+    'Darts', 'DartsDDP', 'SelfDistill', 'DataFreeDistillation',
+    'DAFLDataFreeDistillation'
 ]
9 changes: 7 additions & 2 deletions mmrazor/models/algorithms/distill/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from .configurable import FpnTeacherDistill, SelfDistill, SingleTeacherDistill
+from .configurable import (DAFLDataFreeDistillation, DataFreeDistillation,
+                           FpnTeacherDistill, SelfDistill,
+                           SingleTeacherDistill)

-__all__ = ['SingleTeacherDistill', 'FpnTeacherDistill', 'SelfDistill']
+__all__ = [
+    'SingleTeacherDistill', 'FpnTeacherDistill', 'SelfDistill',
+    'DataFreeDistillation', 'DAFLDataFreeDistillation'
+]
7 changes: 6 additions & 1 deletion mmrazor/models/algorithms/distill/configurable/__init__.py
@@ -1,6 +1,11 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from .datafree_distillation import (DAFLDataFreeDistillation,
+                                    DataFreeDistillation)
 from .fpn_teacher_distill import FpnTeacherDistill
 from .self_distill import SelfDistill
 from .single_teacher_distill import SingleTeacherDistill

-__all__ = ['SelfDistill', 'SingleTeacherDistill', 'FpnTeacherDistill']
+__all__ = [
+    'SelfDistill', 'SingleTeacherDistill', 'FpnTeacherDistill',
+    'DataFreeDistillation', 'DAFLDataFreeDistillation'
+]