Enable running pytorch.torchvision.train with distributed data parallel #1698

Merged: 9 commits, Aug 7, 2023
src/sparseml/pytorch/torchvision/train.py (30 additions & 2 deletions)
@@ -14,6 +14,17 @@

# Adapted from https://github.com/pytorch/vision


+# Note that Distributed-Data-Parallel (DDP) mode cannot be
+# activated when running this code through the CLI
+# (i.e., via sparseml.image_classification.train);
+# in that case Data-Parallel (DP) mode is used instead.
+# To run in DDP mode, launch as follows:
+# CUDA_VISIBLE_DEVICES=<GPUs> python -m torch.distributed.launch \
+#     --nproc_per_node <NUM GPUs> \
+#     sparseml.torchvision.train \
+#     <TRAIN.PY ARGUMENTS>

import datetime
import logging
import math
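For context on the launch command in the header comment above: the launcher starts one process per visible GPU and, by default, passes a --local_rank argument to each process, which is why a hidden --local_rank CLI option is added further down in this diff. The following is a minimal, generic PyTorch sketch of what each launched worker typically does in DDP mode; it is illustrative only and not taken from sparseml.

# --- illustrative sketch (not part of the diff): generic DDP worker setup ---
import argparse

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP


def main():
    parser = argparse.ArgumentParser()
    # torch.distributed.launch injects --local_rank=<n> into each worker's argv.
    parser.add_argument("--local_rank", type=int, default=0)
    args = parser.parse_args()

    # Bind this process to its GPU and join the process group
    # (the launcher exports the rendezvous environment variables).
    torch.cuda.set_device(args.local_rank)
    dist.init_process_group(backend="nccl")

    model = torch.nn.Linear(10, 10).to(args.local_rank)
    # Wrap the model so gradients are all-reduced across processes.
    model = DDP(model, device_ids=[args.local_rank])
    # ... training loop ...


if __name__ == "__main__":
    main()
# --- end sketch ---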
@@ -389,7 +400,7 @@ def collate_fn(batch):
)

_LOGGER.info("Creating model")
-local_rank = args.rank if args.distributed else None
+local_rank = args.local_rank if args.distributed else None
model, arch_key, maybe_dp_device = _create_model(
arch_key=args.arch_key,
local_rank=local_rank,
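A note on why the line changed in the hunk above: in multi-node jobs the global rank identifies a process across all nodes, while the local rank is the GPU index on the current node, and device-placement calls such as torch.cuda.set_device expect the local one. A small sketch of the distinction, assuming the environment variables that PyTorch launchers conventionally export:

# --- illustrative sketch (not part of the diff): global rank vs. local rank ---
import os

global_rank = int(os.environ.get("RANK", "0"))       # unique across all nodes
local_rank = int(os.environ.get("LOCAL_RANK", "0"))  # GPU index on this node
world_size = int(os.environ.get("WORLD_SIZE", "1"))  # total number of processes

# e.g. 2 nodes x 4 GPUs: global ranks 0-7, but local ranks 0-3 on each node,
# so only local_rank is a valid CUDA device index for torch.cuda.set_device.
print(global_rank, local_rank, world_size)
# --- end sketch ---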
@@ -800,7 +811,16 @@ def _create_model(
raise ValueError(
f"Unable to find {arch_key} in ModelRegistry or in torchvision.models"
)
-model, device, _ = model_to_device(model=model, device=device)
+ddp = False
+if local_rank is not None:
+    torch.cuda.set_device(local_rank)

[Reviewer comment] nice - looks like it will make a lot of our defaults work out of the box

+    device = local_rank
+    ddp = True
+model, device, _ = model_to_device(
+    model=model,
+    device=device,
+    ddp=ddp,
+)
return model, arch_key, device
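The new branch above routes local_rank into sparseml's model_to_device helper together with ddp=True. That helper's implementation is not part of this diff; purely as an illustration of the shape such a helper can take (the name model_to_device_sketch, the two-value return, and the body below are assumptions, not sparseml's code):

# --- illustrative sketch (not part of the diff): a DDP-aware device helper ---
from typing import Tuple, Union

import torch
from torch.nn.parallel import DistributedDataParallel


def model_to_device_sketch(
    model: torch.nn.Module,
    device: Union[str, int],
    ddp: bool = False,
) -> Tuple[torch.nn.Module, Union[str, int]]:
    """Hypothetical helper; the real sparseml model_to_device may differ."""
    model = model.to(device)
    if ddp:
        # In DDP mode `device` is the local rank (an int), so it doubles as
        # the single entry of device_ids; requires an initialized process group.
        model = DistributedDataParallel(model, device_ids=[device])
    return model, device
# --- end sketch ---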


@@ -1236,6 +1256,14 @@ def new_func(*args, **kwargs):
"Note: Will use ImageNet values if not specified."
),
)
+@click.option(
+    "--local_rank",
+    "--local-rank",
+    type=int,
+    default=None,
+    help="Local rank for distributed training",
+    hidden=True,  # should not be modified by user
+)
@click.pass_context
def cli(ctx, **kwargs):
"""
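Finally, on the hidden --local_rank option added in the last hunk: the legacy torch.distributed.launch entry point by default appends --local_rank=<n> to each worker's argument list, so the click-based CLI needs to accept it even though users are not expected to pass it themselves. A standalone sketch of the same pattern (hypothetical example script, not sparseml's CLI):

# --- illustrative sketch (not part of the diff): hidden local_rank click option ---
import click


@click.command()
@click.option("--epochs", type=int, default=1, help="Number of training epochs")
@click.option(
    "--local_rank",
    "--local-rank",
    type=int,
    default=None,
    hidden=True,  # injected by the launcher, not meant to be set by users
)
def train(epochs, local_rank):
    distributed = local_rank is not None
    click.echo(f"epochs={epochs} distributed={distributed} local_rank={local_rank}")


if __name__ == "__main__":
    train()
# --- end sketch ---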