diff --git a/utils/torch_utils.py b/utils/torch_utils.py index 86371f127b45..ca32b1999e81 100644 --- a/utils/torch_utils.py +++ b/utils/torch_utils.py @@ -40,10 +40,10 @@ def torch_distributed_zero_first(local_rank: int): def device_count(): - # Returns number of CUDA devices available. Safe version of torch.cuda.device_count(). Only works on Linux. - assert platform.system() == 'Linux', 'device_count() function only works on Linux' + # Returns number of CUDA devices available. Safe version of torch.cuda.device_count(). Supports Linux and Windows + assert platform.system() in ('Linux', 'Windows'), 'device_count() only supported on Linux or Windows' try: - cmd = 'nvidia-smi -L | wc -l' + cmd = 'nvidia-smi -L | wc -l' if platform.system() == 'Linux' else 'nvidia-smi -L | find /c /v ""' # Windows return int(subprocess.run(cmd, shell=True, capture_output=True, check=True).stdout.decode().split()[-1]) except Exception: return 0