Skip to content

Commit

Permalink
Fix: pass `device_ids` through to `DataParallel` in `train_model`
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewilyas committed Dec 1, 2020
1 parent 67ec12b commit 15c9ff6
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions robustness/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def eval_model(args, model, loader, store):
if store: store[consts.LOGS_TABLE].append_row(log_info)
return log_info

def train_model(args, model, loaders, *, checkpoint=None,
def train_model(args, model, loaders, *, checkpoint=None, dp_device_ids=None,
store=None, update_params=None, disable_no_grad=False):
"""
Main function for training a model.
Expand Down Expand Up @@ -266,6 +266,8 @@ def train_model(args, model, loaders, *, checkpoint=None,
`(train_loader, val_loader)`
checkpoint (dict) : a loaded checkpoint previously saved by this library
(if resuming from checkpoint)
dp_device_ids (list|None) : if not ``None``, a list of device ids to
use for DataParallel.
store (cox.Store) : a cox store for logging training progress
update_params (list) : list of parameters to use for training, if None
then all parameters in the model are used (useful for transfer
Expand Down Expand Up @@ -293,7 +295,7 @@ def train_model(args, model, loaders, *, checkpoint=None,

# Put the model into parallel mode
assert not hasattr(model, "module"), "model is already in DataParallel."
model = ch.nn.DataParallel(model).cuda()
model = ch.nn.DataParallel(model, device_ids=dp_device_ids).cuda()

best_prec1, start_epoch = (0, 0)
if checkpoint:
Expand Down

0 comments on commit 15c9ff6

Please sign in to comment.