Distillation support for torchvision script (#1310)
* Add support for `self` distillation and `disable`

* Pull out model creation into a method

* Add support to distill with another model

* Add modifier loss update before backward pass

* bugfix, set loss

* Update src/sparseml/pytorch/torchvision/train.py

Co-authored-by: corey-nm <109536191+corey-nm@users.noreply.github.com>

Co-authored-by: corey-nm <109536191+corey-nm@users.noreply.github.com>
rahul-tuli and corey-nm committed Jan 11, 2023
1 parent da3ffb0 commit adb30a0
Showing 1 changed file with 103 additions and 23 deletions.
src/sparseml/pytorch/torchvision/train.py (103 additions, 23 deletions)
@@ -23,7 +23,7 @@
import warnings
from functools import update_wrapper
from types import SimpleNamespace
from typing import Callable
from typing import Callable, Optional

import torch
import torch.utils.data
@@ -63,6 +63,7 @@ def train_one_epoch(
epoch: int,
args,
log_metrics_fn: Callable[[str, utils.MetricLogger, int, int], None],
manager=None,
model_ema=None,
scaler=None,
) -> utils.MetricLogger:
@@ -91,13 +92,24 @@
start_time = time.time()
image, target = image.to(device), target.to(device)
with torch.cuda.amp.autocast(enabled=scaler is not None):
output = model(image)
outputs = output = model(image)
if isinstance(output, tuple):
# NOTE: sparseml models return two things (logits & probs)
output = output[0]
loss = criterion(output, target)

if steps_accumulated % accum_steps == 0:
if manager is not None:
loss = manager.loss_update(
loss=loss,
module=model,
optimizer=optimizer,
epoch=epoch,
steps_per_epoch=len(data_loader) / accum_steps,
student_outputs=outputs,
student_inputs=image,
)

# first: do training to consume gradients
if scaler is not None:
scaler.scale(loss).backward()
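
For context: manager.loss_update is the hook through which a recipe's distillation modifier can replace the plain task loss with a combined student/teacher objective before the backward pass. A minimal sketch of that general pattern (illustrative only, not SparseML's implementation; temperature and alpha are assumed hyperparameters):

import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, task_loss,
                      temperature=2.0, alpha=0.5):
    # Soften both distributions with the same temperature
    soft_student = F.log_softmax(student_logits / temperature, dim=-1)
    soft_teacher = F.softmax(teacher_logits / temperature, dim=-1)
    # Batch-mean KL term, rescaled by T^2 to keep gradient magnitudes comparable
    kd = F.kl_div(soft_student, soft_teacher, reduction="batchmean") * temperature**2
    # Blend the soft (teacher) term with the hard-label task loss
    return alpha * kd + (1.0 - alpha) * task_loss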
@@ -348,27 +360,28 @@ def collate_fn(batch):
)

_LOGGER.info("Creating model")
if args.arch_key in ModelRegistry.available_keys():
with torch_distributed_zero_first(args.rank if args.distributed else None):
model = ModelRegistry.create(
key=args.arch_key,
pretrained=args.pretrained,
pretrained_path=args.checkpoint_path,
pretrained_dataset=args.pretrained_dataset,
num_classes=num_classes,
)
elif args.arch_key in torchvision.models.__dict__:
# fall back to torchvision
model = torchvision.models.__dict__[args.arch_key](
pretrained=args.pretrained, num_classes=num_classes
)
if args.checkpoint_path is not None:
load_model(args.checkpoint_path, model, strict=True)
else:
raise ValueError(
f"Unable to find {args.arch_key} in ModelRegistry or in torchvision.models"
local_rank = args.rank if args.distributed else None
model = _create_model(
arch_key=args.arch_key,
local_rank=local_rank,
pretrained=args.pretrained,
checkpoint_path=args.checkpoint_path,
pretrained_dataset=args.pretrained_dataset,
device=device,
num_classes=num_classes,
)

if args.distill_teacher not in ["self", "disable", None]:
_LOGGER.info("Instantiating teacher")
args.distill_teacher = _create_model(
arch_key=args.teacher_arch_key,
local_rank=local_rank,
pretrained=True, # teacher is always pretrained
pretrained_dataset=args.pretrained_teacher_dataset,
checkpoint_path=args.distill_teacher,
device=device,
num_classes=num_classes,
)
model.to(device)

if args.distributed and args.sync_bn:
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
@@ -542,7 +555,12 @@ def log_metrics(tag: str, metrics: utils.MetricLogger, epoch: int, epoch_step: i
)

if manager is not None:
manager.initialize(model, epoch=args.start_epoch, loggers=logger)
manager.initialize(
model,
epoch=args.start_epoch,
loggers=logger,
distillation_teacher=args.distill_teacher,
)
optimizer = manager.modify(
model, optimizer, steps_per_epoch=steps_per_epoch, epoch=args.start_epoch
)
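
The distillation_teacher value handed to manager.initialize above may be the string 'self', the string 'disable', None, or the teacher module instantiated earlier. A hypothetical helper showing how those values map to behavior (illustrative of the CLI semantics only, not SparseML's internal code):

def resolve_teacher(distill_teacher, student_model):
    # Hypothetical mapping of the accepted --distill-teacher values
    if distill_teacher in (None, "disable"):
        return None              # distillation switched off
    if distill_teacher == "self":
        return student_model     # self-distillation: model teaches itself
    return distill_teacher       # a previously built teacher module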
@@ -578,6 +596,7 @@ def log_metrics(tag: str, metrics: utils.MetricLogger, epoch: int, epoch_step: i
epoch,
args,
log_metrics,
manager=manager,
model_ema=model_ema,
scaler=scaler,
)
@@ -650,6 +669,39 @@ def log_metrics(tag: str, metrics: utils.MetricLogger, epoch: int, epoch_step: i
_LOGGER.info(f"Training time {total_time_str}")


def _create_model(
arch_key: Optional[str] = None,
local_rank=None,
pretrained: Optional[bool] = False,
checkpoint_path: Optional[str] = None,
pretrained_dataset: Optional[str] = None,
device=None,
num_classes=None,
):
if arch_key in ModelRegistry.available_keys():
with torch_distributed_zero_first(local_rank):
model = ModelRegistry.create(
key=arch_key,
pretrained=pretrained,
pretrained_path=checkpoint_path,
pretrained_dataset=pretrained_dataset,
num_classes=num_classes,
)
elif arch_key in torchvision.models.__dict__:
# fall back to torchvision
model = torchvision.models.__dict__[arch_key](
pretrained=pretrained, num_classes=num_classes
)
if checkpoint_path is not None:
load_model(checkpoint_path, model, strict=True)
else:
raise ValueError(
f"Unable to find {arch_key} in ModelRegistry or in torchvision.models"
)
model.to(device)
return model
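
A hypothetical standalone call to the helper above (the arch key, checkpoint path, dataset, and device below are placeholders):

teacher = _create_model(
    arch_key="resnet50",
    local_rank=None,                  # single-process run, no distributed barrier
    pretrained=True,                  # teacher starts from trained weights
    checkpoint_path="./teacher.pth",  # placeholder path
    pretrained_dataset="imagenet",
    device="cuda",
    num_classes=1000,
)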


def _get_lr_scheduler(args, optimizer, checkpoint=None, manager=None):
lr_scheduler = None

@@ -1026,6 +1078,34 @@ def new_func(*args, **kwargs):
help="Save the best validation result after the given "
"epoch completes until the end of training",
)
@click.option(
"--distill-teacher",
default=None,
type=str,
help="Teacher model for distillation (a trained image classification model)"
" can be set to 'self' for self-distillation and 'disable' to switch-off"
" distillation, additionally can also take in a SparseZoo stub",
)
@click.option(
"--pretrained-teacher-dataset",
default=None,
type=str,
help=(
"The dataset to load pretrained weights for the teacher"
"Load the default dataset for the architecture if set to None. "
"examples:`imagenet`, `cifar10`, etc..."
),
)
@click.option(
"--teacher-arch-key",
default=None,
type=str,
help=(
"The architecture key for teacher image classification model; "
"example: `resnet50`, `mobilenet`. "
"Note: Will be read from the checkpoint if not specified"
),
)
@click.pass_context
def cli(ctx, **kwargs):
"""

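To exercise the new options end to end, one possible smoke test uses click's testing utilities (a sketch: the real command needs dataset and training options not shown here, and the paths and keys are placeholders):

from click.testing import CliRunner

from sparseml.pytorch.torchvision.train import cli

runner = CliRunner()
result = runner.invoke(
    cli,
    [
        "--distill-teacher", "./teacher.pth",  # or "self" / "disable"
        "--teacher-arch-key", "resnet50",
        "--pretrained-teacher-dataset", "imagenet",
    ],
)
print(result.exit_code)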