diff --git a/src/sparseml/pytorch/torchvision/train.py b/src/sparseml/pytorch/torchvision/train.py
index d80ad837a30..9045c5a452e 100644
--- a/src/sparseml/pytorch/torchvision/train.py
+++ b/src/sparseml/pytorch/torchvision/train.py
@@ -40,6 +40,7 @@
 from sparseml.pytorch.torchvision import presets, transforms, utils
 from sparseml.pytorch.torchvision.sampler import RASampler
 from sparseml.pytorch.utils.helpers import (
+    default_device,
     download_framework_model_by_recipe_type,
     torch_distributed_zero_first,
 )
@@ -49,7 +50,7 @@
     TensorBoardLogger,
     WANDBLogger,
 )
-from sparseml.pytorch.utils.model import load_model
+from sparseml.pytorch.utils.model import load_model, model_to_device
 from sparsezoo import Model
 
 
@@ -332,7 +333,10 @@ def main(args):
 
     _LOGGER.info(args)
 
-    device = torch.device(args.device)
+    if not args.device:
+        args.device = default_device()
+
+    device = args.device
 
     if args.use_deterministic_algorithms:
         torch.backends.cudnn.benchmark = False
@@ -382,7 +386,7 @@ def collate_fn(batch):
 
     _LOGGER.info("Creating model")
     local_rank = args.rank if args.distributed else None
-    model, arch_key = _create_model(
+    model, arch_key, maybe_dp_device = _create_model(
         arch_key=args.arch_key,
         local_rank=local_rank,
         pretrained=args.pretrained,
@@ -394,7 +398,7 @@ def collate_fn(batch):
 
     if args.distill_teacher not in ["self", "disable", None]:
         _LOGGER.info("Instantiating teacher")
-        distill_teacher, _ = _create_model(
+        distill_teacher, _, _ = _create_model(
             arch_key=args.teacher_arch_key,
             local_rank=local_rank,
             pretrained=True,  # teacher is always pretrained
@@ -405,6 +409,7 @@ def collate_fn(batch):
         )
     else:
         distill_teacher = args.distill_teacher
+    device = maybe_dp_device
 
     if args.distributed and args.sync_bn:
         model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
@@ -763,8 +768,8 @@ def _create_model(
         raise ValueError(
             f"Unable to find {arch_key} in ModelRegistry or in torchvision.models"
         )
-    model.to(device)
-    return model, arch_key
+    model, device, _ = model_to_device(model=model, device=device)
+    return model, arch_key, device
 
 
 def _get_lr_scheduler(args, optimizer, checkpoint=None, manager=None):
@@ -922,9 +927,11 @@ def new_func(*args, **kwargs):
 )
 @click.option(
     "--device",
-    default="cuda",
+    default=None,
     type=str,
-    help="device (Use cuda or cpu)",
+    help=(
+        "device (Use cuda for all gpus, else use `cuda:device_id,device_id` " "or cpu)"
+    ),
 )
 @click.option(
     "-b",
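
For readers skimming the diff, here is a minimal sketch of the device-resolution flow it introduces. It relies only on what the diff itself shows: `default_device()` takes no arguments and supplies a device when `--device` is omitted, and `model_to_device(model=..., device=...)` returns a `(model, device, ...)` tuple. The `resolve_and_place` helper name and the exact return semantics are assumptions for illustration, not part of the PR.

```python
# Hypothetical sketch of the device handling introduced above (not part of the PR).
from sparseml.pytorch.utils.helpers import default_device
from sparseml.pytorch.utils.model import model_to_device


def resolve_and_place(model, requested_device=None):
    # Mirror main(): fall back to the framework's default device
    # when no --device flag is given.
    device = requested_device or default_device()
    # Mirror _create_model(): move (and possibly data-parallelize) the model,
    # keeping the resolved device so the caller can reuse it.
    model, device, _ = model_to_device(model=model, device=device)
    return model, device
```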