Distillation support for torchvision script #1310

Merged · 9 commits · Jan 11, 2023
110 changes: 89 additions & 21 deletions src/sparseml/pytorch/torchvision/train.py
@@ -23,6 +23,7 @@
import warnings
from functools import update_wrapper
from types import SimpleNamespace
from typing import Optional

import torch
import torch.utils.data
@@ -332,27 +333,28 @@ def collate_fn(batch):
)

_LOGGER.info("Creating model")
if args.arch_key in ModelRegistry.available_keys():
with torch_distributed_zero_first(args.rank if args.distributed else None):
model = ModelRegistry.create(
key=args.arch_key,
pretrained=args.pretrained,
pretrained_path=args.checkpoint_path,
pretrained_dataset=args.pretrained_dataset,
num_classes=num_classes,
)
elif args.arch_key in torchvision.models.__dict__:
# fall back to torchvision
model = torchvision.models.__dict__[args.arch_key](
pretrained=args.pretrained, num_classes=num_classes
)
if args.checkpoint_path is not None:
load_model(args.checkpoint_path, model, strict=True)
else:
raise ValueError(
f"Unable to find {args.arch_key} in ModelRegistry or in torchvision.models"
local_rank = args.rank if args.distributed else None
model = _create_model(
arch_key=args.arch_key,
local_rank=local_rank,
pretrained=args.pretrained,
checkpoint_path=args.checkpoint_path,
pretrained_dataset=args.pretrained_dataset,
device=device,
num_classes=num_classes,
)

if args.distill_teacher not in ["self", "disable", None]:
_LOGGER.info("Instantiating teacher")
args.distill_teacher = _create_model(
arch_key=args.teacher_arch_key,
local_rank=local_rank,
pretrained=True, # teacher is always pretrained
pretrained_dataset=args.pretrained_teacher_dataset,
checkpoint_path=args.distill_teacher,
device=device,
num_classes=num_classes,
)
model.to(device)

if args.distributed and args.sync_bn:
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
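
The teacher instantiated above is not consumed by the training loop directly; it is handed to the recipe manager (see the `manager.initialize` change further down), which applies whatever distillation modifier the recipe defines. As a rough mental model, knowledge distillation blends the usual cross-entropy loss with a temperature-softened teacher/student divergence. The sketch below illustrates that idea in plain PyTorch; the `temperature`/`hardness` names and the exact blending convention are illustrative assumptions, not SparseML's internal implementation.

import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, targets,
                      temperature=2.0, hardness=0.5):
    # Hard loss: ordinary cross-entropy against the ground-truth labels
    hard_loss = F.cross_entropy(student_logits, targets)
    # Soft loss: KL divergence between temperature-softened teacher and
    # student distributions, scaled by T^2 to keep gradient magnitudes comparable
    soft_loss = F.kl_div(
        F.log_softmax(student_logits / temperature, dim=1),
        F.softmax(teacher_logits / temperature, dim=1),
        reduction="batchmean",
    ) * (temperature ** 2)
    # Blend the two terms; here `hardness` weights the distillation term
    return (1.0 - hardness) * hard_loss + hardness * soft_loss

With the change in this PR that blending happens inside the recipe's modifier once `distillation_teacher` is passed to `manager.initialize`, so the training loop itself stays unchanged.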
@@ -516,7 +518,12 @@ def log_metrics(tag: str, metrics: utils.MetricLogger, epoch: int):
)

if manager is not None:
manager.initialize(model, epoch=args.start_epoch, loggers=logger)
manager.initialize(
model,
epoch=args.start_epoch,
loggers=logger,
distillation_teacher=args.distill_teacher,
)
optimizer = manager.modify(
model, optimizer, len(data_loader), epoch=args.start_epoch
)
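
For reference, the `distillation_teacher` argument added here is what connects the CLI's `--distill-teacher` flag to the recipe. Below is a minimal, self-contained sketch of that wiring, assuming a recipe that uses SparseML's distillation modifier; the modifier fields (`hardness`, `temperature`), the inline-YAML use of `from_yaml`, and the toy student/teacher modules are assumptions for illustration, not taken from this PR.

import torch
from torch.utils.data import DataLoader, TensorDataset
from sparseml.pytorch.optim import ScheduledModifierManager

# Toy stand-ins so the sketch is self-contained
model = torch.nn.Linear(8, 2)            # student
teacher_model = torch.nn.Linear(8, 2)    # teacher
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
data_loader = DataLoader(
    TensorDataset(torch.randn(32, 8), torch.randint(0, 2, (32,))), batch_size=8
)

# Hypothetical recipe; field names assumed from typical SparseML distillation recipes
recipe = """
modifiers:
    - !EpochRangeModifier
        start_epoch: 0.0
        end_epoch: 10.0

    - !DistillationModifier
        start_epoch: 0.0
        hardness: 0.5
        temperature: 2.0
"""

manager = ScheduledModifierManager.from_yaml(recipe)
# Mirrors the diff: the teacher module is passed to the manager, which routes
# it to the distillation modifier; the optimizer is then wrapped as usual
manager.initialize(model, epoch=0, distillation_teacher=teacher_model)
optimizer = manager.modify(model, optimizer, len(data_loader), epoch=0)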
@@ -623,6 +630,39 @@ def log_metrics(tag: str, metrics: utils.MetricLogger, epoch: int):
_LOGGER.info(f"Training time {total_time_str}")


def _create_model(
arch_key: Optional[str] = None,
local_rank=None,
pretrained: Optional[bool] = False,
checkpoint_path: Optional[str] = None,
pretrained_dataset: Optional[str] = None,
device=None,
num_classes=None,
):
if arch_key in ModelRegistry.available_keys():
with torch_distributed_zero_first(local_rank):
model = ModelRegistry.create(
key=arch_key,
pretrained=pretrained,
pretrained_path=checkpoint_path,
pretrained_dataset=pretrained_dataset,
num_classes=num_classes,
)
elif arch_key in torchvision.models.__dict__:
# fall back to torchvision
model = torchvision.models.__dict__[arch_key](
pretrained=pretrained, num_classes=num_classes
)
if checkpoint_path is not None:
load_model(checkpoint_path, model, strict=True)
else:
raise ValueError(
f"Unable to find {arch_key} in ModelRegistry or in torchvision.models"
)
model.to(device)
return model


def _get_lr_scheduler(args, optimizer, checkpoint=None, manager=None):
lr_scheduler = None

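Returning to the `_create_model` helper added above: it first tries SparseML's ModelRegistry and only falls back to torchvision when the key is unknown, loading any explicit checkpoint afterwards. A hypothetical pair of calls is sketched below; the arch keys, dataset, device, and checkpoint path are illustrative assumptions, not values from this PR.

# Architecture known to SparseML's ModelRegistry: weights are resolved through
# ModelRegistry.create via pretrained / pretrained_dataset / pretrained_path
student = _create_model(
    arch_key="resnet50",
    local_rank=None,              # single-process run
    pretrained=True,
    pretrained_dataset="imagenet",
    device="cuda:0",
    num_classes=1000,
)

# Architecture only present in torchvision.models: the helper falls back to
# torchvision, then loads the given checkpoint with load_model(strict=True)
teacher = _create_model(
    arch_key="wide_resnet50_2",
    pretrained=True,
    checkpoint_path="/path/to/teacher.pth",  # hypothetical path
    device="cuda:0",
    num_classes=1000,
)
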
@@ -993,6 +1033,34 @@ def new_func(*args, **kwargs):
help="Save the best validation result after the given "
"epoch completes until the end of training",
)
@click.option(
"--distill-teacher",
default=None,
type=str,
help="Teacher model for distillation (a trained image classification model)"
" can be set to 'self' for self-distillation and 'disable' to switch-off"
" distillation, additionally can also take in a SparseZoo stub",
)
@click.option(
"--pretrained-teacher-dataset",
default=None,
type=str,
help=(
"The dataset to load pretrained weights for the teacher"
"Load the default dataset for the architecture if set to None. "
"examples:`imagenet`, `cifar10`, etc..."
),
)
@click.option(
"--teacher-arch-key",
default=None,
type=str,
help=(
"The architecture key for teacher image classification model; "
"example: `resnet50`, `mobilenet`. "
"Note: Will be read from the checkpoint if not specified"
),
)
@click.pass_context
def cli(ctx, **kwargs):
"""