From 2deb43657162312a2428fc2acf997296e99f8b28 Mon Sep 17 00:00:00 2001 From: xylieong <61135607+xylieong@users.noreply.github.com> Date: Thu, 19 May 2022 21:48:44 +0800 Subject: [PATCH] Added Windows cmd to count GPU devices (#7891) * Added Windows cmd to count GPU devices * Cleanup Co-authored-by: Glenn Jocher --- utils/torch_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/utils/torch_utils.py b/utils/torch_utils.py index 86371f127b45..ca32b1999e81 100644 --- a/utils/torch_utils.py +++ b/utils/torch_utils.py @@ -40,10 +40,10 @@ def torch_distributed_zero_first(local_rank: int): def device_count(): - # Returns number of CUDA devices available. Safe version of torch.cuda.device_count(). Only works on Linux. - assert platform.system() == 'Linux', 'device_count() function only works on Linux' + # Returns number of CUDA devices available. Safe version of torch.cuda.device_count(). Supports Linux and Windows + assert platform.system() in ('Linux', 'Windows'), 'device_count() only supported on Linux or Windows' try: - cmd = 'nvidia-smi -L | wc -l' + cmd = 'nvidia-smi -L | wc -l' if platform.system() == 'Linux' else 'nvidia-smi -L | find /c /v ""' # Windows return int(subprocess.run(cmd, shell=True, capture_output=True, check=True).stdout.decode().split()[-1]) except Exception: return 0