DataParallel support for torchvision (#1332)
* Add: `DataParallel` support to torchvision train script

* Resolve merge conflicts + quality

* Update: `_create_model` to return the right device
Update: all `_create_model` calling code to accept a third argument
Update: device to `maybe_dp_device` after teacher creation

---------

Co-authored-by: corey-nm <109536191+corey-nm@users.noreply.github.com>
rahul-tuli and corey-nm committed Feb 10, 2023
1 parent 83f5d2b commit a8a6992
Showing 1 changed file with 15 additions and 8 deletions.
src/sparseml/pytorch/torchvision/train.py
@@ -40,6 +40,7 @@
 from sparseml.pytorch.torchvision import presets, transforms, utils
 from sparseml.pytorch.torchvision.sampler import RASampler
 from sparseml.pytorch.utils.helpers import (
+    default_device,
     download_framework_model_by_recipe_type,
     torch_distributed_zero_first,
 )
@@ -49,7 +50,7 @@
     TensorBoardLogger,
     WANDBLogger,
 )
-from sparseml.pytorch.utils.model import load_model
+from sparseml.pytorch.utils.model import load_model, model_to_device
 from sparsezoo import Model


@@ -332,7 +333,10 @@ def main(args):
 
     _LOGGER.info(args)
 
-    device = torch.device(args.device)
+    if not args.device:
+        args.device = default_device()
+
+    device = args.device
 
     if args.use_deterministic_algorithms:
         torch.backends.cudnn.benchmark = False
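
Note on the new default: when `--device` is omitted, the script now asks `default_device()` for one. That helper is imported from `sparseml.pytorch.utils.helpers`; the sketch below is only an approximation of its likely behavior (fall back to CPU, otherwise expose every visible GPU so the multi-id string can feed the DataParallel path), not the actual implementation.

import torch

def default_device() -> str:
    # Sketch only: approximates sparseml.pytorch.utils.helpers.default_device
    if not torch.cuda.is_available():
        return "cpu"
    if torch.cuda.device_count() == 1:
        return "cuda"
    # Several GPUs visible: list every id, e.g. "cuda:0,1,2,3"
    ids = ",".join(str(i) for i in range(torch.cuda.device_count()))
    return f"cuda:{ids}"
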
@@ -382,7 +386,7 @@ def collate_fn(batch):
 
     _LOGGER.info("Creating model")
     local_rank = args.rank if args.distributed else None
-    model, arch_key = _create_model(
+    model, arch_key, maybe_dp_device = _create_model(
         arch_key=args.arch_key,
         local_rank=local_rank,
         pretrained=args.pretrained,
@@ -394,7 +398,7 @@ def collate_fn(batch):
 
     if args.distill_teacher not in ["self", "disable", None]:
         _LOGGER.info("Instantiating teacher")
-        distill_teacher, _ = _create_model(
+        distill_teacher, _, _ = _create_model(
             arch_key=args.teacher_arch_key,
             local_rank=local_rank,
             pretrained=True,  # teacher is always pretrained
@@ -405,6 +409,7 @@ def collate_fn(batch):
         )
     else:
         distill_teacher = args.distill_teacher
+    device = maybe_dp_device
 
     if args.distributed and args.sync_bn:
         model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
@@ -763,8 +768,8 @@ def _create_model(
         raise ValueError(
             f"Unable to find {arch_key} in ModelRegistry or in torchvision.models"
         )
-    model.to(device)
-    return model, arch_key
+    model, device, _ = model_to_device(model=model, device=device)
+    return model, arch_key, device
 
 
 def _get_lr_scheduler(args, optimizer, checkpoint=None, manager=None):
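
Here `model_to_device` replaces the bare `model.to(device)`: it returns the (possibly `DataParallel`-wrapped) model together with the device it settled on, which `main` then adopts as `maybe_dp_device`. A minimal sketch of such a helper, assuming the `cuda:device_id,device_id` string convention from the CLI help below; the real version is imported from `sparseml.pytorch.utils.model` and may handle more cases:

from typing import List, Optional, Tuple

import torch

def model_to_device(
    model: torch.nn.Module, device: str
) -> Tuple[torch.nn.Module, str, Optional[List[int]]]:
    # Sketch only: approximates sparseml.pytorch.utils.model.model_to_device
    ids: Optional[List[int]] = None
    if device.startswith("cuda:") and "," in device:
        # "cuda:0,2" -> ids [0, 2]; the first id becomes the master device
        ids = [int(i) for i in device.split(":", 1)[1].split(",")]
        device = f"cuda:{ids[0]}"
    model = model.to(device)
    if ids is not None and len(ids) > 1:
        # Replicate the module across the listed GPUs; inputs are
        # scattered along the batch dimension at forward time
        model = torch.nn.DataParallel(model, device_ids=ids)
    return model, device, ids
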
@@ -922,9 +927,11 @@ def new_func(*args, **kwargs):
     )
 @click.option(
     "--device",
-    default="cuda",
+    default=None,
     type=str,
-    help="device (Use cuda or cpu)",
+    help=(
+        "device (Use cuda for all gpus, else use `cuda:device_id,device_id` " "or cpu)"
+    ),
 )
 @click.option(
     "-b",
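
Usage after this change: leaving `--device` unset lets the script auto-select via `default_device()`; explicit values such as `--device cpu`, `--device cuda`, or the multi-GPU form `--device cuda:0,1` remain available, with the last one being what routes the model through `DataParallel`.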
