diff --git a/utils/torch_utils.py b/utils/torch_utils.py index 2a45f434c6a5..6d85e09a2eb4 100644 --- a/utils/torch_utils.py +++ b/utils/torch_utils.py @@ -61,7 +61,7 @@ def select_device(device='', batch_size=0, newline=True): if cpu: os.environ['CUDA_VISIBLE_DEVICES'] = '-1' # force torch.cuda.is_available() = False elif device: # non-cpu device requested - nd = torch.cuda.device_count() # number of CUDA devices + nd = len(os.getenv('CUDA_VISIBLE_DEVICES', '').replace(',','')) # number of CUDA devices assert torch.cuda.is_available(), 'CUDA is not available, use `--device cpu` or do not pass a --device' assert nd > int(max(device.split(','))), f'Invalid `--device {device}` request, valid devices are 0 - {nd - 1}' os.environ['CUDA_VISIBLE_DEVICES'] = device # set environment variable (must be after asserts)