Skip to content

Commit

Permalink
cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
glenn-jocher committed Jun 18, 2021
1 parent e8493c6 commit fb342fc
Showing 1 changed file with 1 addition and 7 deletions.
8 changes: 1 addition & 7 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,7 +484,6 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
parser.add_argument('--cache-images', action='store_true', help='cache images for faster training')
parser.add_argument('--image-weights', action='store_true', help='use weighted image selection for training')
parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
- parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify')
parser.add_argument('--multi-scale', action='store_true', help='vary img-size +/- 50%%')
parser.add_argument('--single-cls', action='store_true', help='train multi-class data as single-class')
parser.add_argument('--adam', action='store_true', help='use torch.optim.Adam() optimizer')
Expand All @@ -503,11 +502,6 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
parser.add_argument('--artifact_alias', type=str, default="latest", help='version of dataset artifact to be used')
opt = parser.parse_args()

- LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html
- RANK = int(os.getenv('RANK', -1))
- WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))
-
- # Set DDP variables
set_logging(RANK)
if RANK in [-1, 0]:
check_git_status()
Expand Down Expand Up @@ -535,7 +529,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
device = select_device(opt.device, batch_size=opt.batch_size)
print({'RANK': RANK, 'LOCAL_RANK': LOCAL_RANK, 'WORLD_SIZE': WORLD_SIZE})
if LOCAL_RANK != -1:
- assert torch.cuda.device_count() > LOCAL_RANK
+ assert torch.cuda.device_count() > LOCAL_RANK, 'too few GPUS for DDP command'
torch.cuda.set_device(LOCAL_RANK)
device = torch.device('cuda', LOCAL_RANK)
dist.init_process_group(backend="gloo") # distributed backend
Expand Down

0 comments on commit fb342fc

Please sign in to comment.