Skip to content

Commit

Permalink
fix handling for device
Browse files Browse the repository at this point in the history
  • Loading branch information
winglian committed Sep 8, 2023
1 parent 04625e0 commit 8ff0109
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion src/axolotl/utils/bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def gpu_memory_usage_smi(device=0):


def log_gpu_memory_usage(log, msg, device):
if not torch.cuda.is_available():
if not torch.cuda.is_available() or device == "auto":
return (0, 0, 0)

usage, cache, misc = gpu_memory_usage_all(device)
Expand Down
2 changes: 1 addition & 1 deletion src/axolotl/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@ def get_device():
except Exception: # pylint: disable=broad-exception-caught
return "cpu"

cfg.device = get_device()
if cfg.world_size == 1:
cfg.device_map = "auto"
else:
cfg.device = get_device()
if cfg.device.startswith("cuda"):
cfg.device_map = {"": cfg.local_rank}
else:
Expand Down

0 comments on commit 8ff0109

Please sign in to comment.