From 600b0032247d6db8460eb2a862c4b698d577e659 Mon Sep 17 00:00:00 2001 From: xylieong <61135607+xylieong@users.noreply.github.com> Date: Thu, 19 May 2022 21:34:02 +0800 Subject: [PATCH 1/2] Added Windows cmd to count GPU devices --- utils/torch_utils.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/utils/torch_utils.py b/utils/torch_utils.py index 86371f127b45..8782735a8a6c 100644 --- a/utils/torch_utils.py +++ b/utils/torch_utils.py @@ -41,9 +41,13 @@ def torch_distributed_zero_first(local_rank: int): def device_count(): # Returns number of CUDA devices available. Safe version of torch.cuda.device_count(). Only works on Linux. - assert platform.system() == 'Linux', 'device_count() function only works on Linux' + system_platform = platform.system() + assert system_platform == 'Linux' or 'Windows', 'device_count() function only works on Linux or Windows' try: - cmd = 'nvidia-smi -L | wc -l' + if system_platform == 'Linux': + cmd = 'nvidia-smi -L | wc -l' + elif system_platform == 'Windows': + cmd = 'nvidia-smi -L | find /c /v ""' return int(subprocess.run(cmd, shell=True, capture_output=True, check=True).stdout.decode().split()[-1]) except Exception: return 0 From 1a3aafc908c4320da38f0a5d278d0805a8b8c1de Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Thu, 19 May 2022 15:47:57 +0200 Subject: [PATCH 2/2] Cleanup --- utils/torch_utils.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/utils/torch_utils.py b/utils/torch_utils.py index 8782735a8a6c..ca32b1999e81 100644 --- a/utils/torch_utils.py +++ b/utils/torch_utils.py @@ -40,14 +40,10 @@ def torch_distributed_zero_first(local_rank: int): def device_count(): - # Returns number of CUDA devices available. Safe version of torch.cuda.device_count(). Only works on Linux. - system_platform = platform.system() - assert system_platform == 'Linux' or 'Windows', 'device_count() function only works on Linux or Windows' + # Returns number of CUDA devices available. Safe version of torch.cuda.device_count(). Supports Linux and Windows + assert platform.system() in ('Linux', 'Windows'), 'device_count() only supported on Linux or Windows' try: - if system_platform == 'Linux': - cmd = 'nvidia-smi -L | wc -l' - elif system_platform == 'Windows': - cmd = 'nvidia-smi -L | find /c /v ""' + cmd = 'nvidia-smi -L | wc -l' if platform.system() == 'Linux' else 'nvidia-smi -L | find /c /v ""' # Windows return int(subprocess.run(cmd, shell=True, capture_output=True, check=True).stdout.decode().split()[-1]) except Exception: return 0