
Commit 8e4dc20

Enable running pytorch.torchvision.train with distributed data parallel (#1698)

* update get_named_layers_and_params_by_regex in src/sparseml/pytorch/utils/helpers.py to still match in the DDP training case, where "module." is prepended to layer names (see the sketch after this list)

* make pytorch.torchvision.train.py work with torch.distributed.launch

* undo accidental add

* expand comment and fix for CLI

* make quality

* fix overly long line
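
As referenced in the first bullet, here is a minimal sketch of the matching fix (hypothetical helper and layer names for illustration; the actual logic lives in get_named_layers_and_params_by_regex in src/sparseml/pytorch/utils/helpers.py):

    import re

    DDP_PREFIX = "module."

    def name_matches(pattern: str, layer_name: str) -> bool:
        # DDP wraps the model in a `module` attribute, so named_parameters()
        # yields e.g. "module.sections.0.conv.weight" instead of
        # "sections.0.conv.weight"; strip the prefix before matching.
        if layer_name.startswith(DDP_PREFIX):
            layer_name = layer_name[len(DDP_PREFIX):]
        return re.match(pattern, layer_name) is not None

    print(name_matches(r"sections\.0\.conv", "module.sections.0.conv"))  # True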
ohaijen authored Aug 7, 2023
1 parent b52a13d commit 8e4dc20
Showing 1 changed file with 30 additions and 2 deletions.
32 changes: 30 additions & 2 deletions src/sparseml/pytorch/torchvision/train.py
@@ -14,6 +14,17 @@

# Adapted from https://github.com/pytorch/vision


# Note that Distributed Data Parallel (DDP) mode cannot be
# activated when running this code via the CLI
# (i.e., by using sparseml.image_classification.train);
# in that case Data Parallel (DP) mode is used instead.
# To run in DDP mode, launch as follows:
# CUDA_VISIBLE_DEVICES=<GPUs> python -m torch.distributed.launch \
#     --nproc_per_node <NUM GPUs> \
#     sparseml.torchvision.train \
#     <TRAIN.PY ARGUMENTS>

import datetime
import logging
import math
@@ -389,7 +400,7 @@ def collate_fn(batch):
)

_LOGGER.info("Creating model")
- local_rank = args.rank if args.distributed else None
+ local_rank = args.local_rank if args.distributed else None
model, arch_key, maybe_dp_device = _create_model(
arch_key=args.arch_key,
local_rank=local_rank,
@@ -800,7 +811,16 @@ def _create_model(
raise ValueError(
f"Unable to find {arch_key} in ModelRegistry or in torchvision.models"
)
- model, device, _ = model_to_device(model=model, device=device)
+ ddp = False
+ if local_rank is not None:
+     torch.cuda.set_device(local_rank)
+     device = local_rank
+     ddp = True
+ model, device, _ = model_to_device(
+     model=model,
+     device=device,
+     ddp=ddp,
+ )
return model, arch_key, device
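
In the hunk above, local_rank is the per-node GPU index handed to each DDP process (unlike the global rank used before this fix), so torch.cuda.set_device pins the process to its own GPU before model_to_device wraps the model. As an illustration of the general PyTorch idiom a helper like model_to_device follows when ddp=True (a sketch, not sparseml's actual implementation; assumes the process group is already initialized):

    import torch
    from torch.nn.parallel import DistributedDataParallel

    def wrap_for_ddp(model: torch.nn.Module, local_rank: int) -> torch.nn.Module:
        # One process per GPU: pin this process to its device.
        torch.cuda.set_device(local_rank)
        model = model.to(local_rank)
        # DDP replicates the model and all-reduces gradients during backward().
        return DistributedDataParallel(model, device_ids=[local_rank])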


@@ -1236,6 +1256,14 @@ def new_func(*args, **kwargs):
"Note: Will use ImageNet values if not specified."
),
)
@click.option(
"--local_rank",
"--local-rank",
type=int,
default=None,
help="Local rank for distributed training",
hidden=True, # should not be modified by user
)
@click.pass_context
def cli(ctx, **kwargs):
"""
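
The hidden --local_rank option added above exists because torch.distributed.launch, in its default mode (without --use_env), appends --local_rank=<rank> to each subprocess's argument list; it is hidden so users do not set it by hand. A minimal standalone sketch of the same pattern (hypothetical script, not part of this commit):

    import click

    @click.command()
    @click.option(
        "--local_rank",
        "--local-rank",
        type=int,
        default=None,
        hidden=True,  # injected by torch.distributed.launch, not set by users
    )
    def main(local_rank):
        # Treat the presence of --local_rank as the signal for distributed mode.
        distributed = local_rank is not None
        click.echo(f"distributed={distributed}, local_rank={local_rank}")

    if __name__ == "__main__":
        main()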
