diff --git a/utils/torch_utils.py b/utils/torch_utils.py index e441b8e0626a..2cb09e71ce71 100644 --- a/utils/torch_utils.py +++ b/utils/torch_utils.py @@ -61,7 +61,7 @@ def select_device(device='', batch_size=None): os.environ['CUDA_VISIBLE_DEVICES'] = device # set environment variable assert torch.cuda.is_available(), f'CUDA unavailable, invalid device {device} requested' # check availability - cuda = torch.cuda.is_available() and not cpu + cuda = not cpu and torch.cuda.is_available() if cuda: n = torch.cuda.device_count() if n > 1 and batch_size: # check that batch_size is compatible with device_count