diff --git a/src/axolotl/utils/bench.py b/src/axolotl/utils/bench.py
index b460b2ba7c..794c1e50df 100644
--- a/src/axolotl/utils/bench.py
+++ b/src/axolotl/utils/bench.py
@@ -1,14 +1,40 @@
 """Benchmarking and measurement utilities"""
+import functools
 
 import pynvml
 import torch
 from pynvml.nvml import NVMLError
 
 
+def check_cuda_device(default_value):
+    """
+    wraps a function and returns the default value instead of running the
+    wrapped function if cuda isn't available or the device is auto
+    :param default_value:
+    :return:
+    """
+
+    def actual_decorator(func):
+        @functools.wraps(func)
+        def wrapper(*args, **kwargs):
+            device = kwargs.get("device", args[0] if args else None)
+
+            if not torch.cuda.is_available() or device == "auto":
+                return default_value
+
+            return func(*args, **kwargs)
+
+        return wrapper
+
+    return actual_decorator
+
+
+@check_cuda_device(0.0)
 def gpu_memory_usage(device=0):
     return torch.cuda.memory_allocated(device) / 1024.0**3
 
 
+@check_cuda_device((0.0, 0.0, 0.0))
 def gpu_memory_usage_all(device=0):
     usage = torch.cuda.memory_allocated(device) / 1024.0**3
     reserved = torch.cuda.memory_reserved(device) / 1024.0**3
@@ -16,6 +42,7 @@ def gpu_memory_usage_all(device=0):
     return usage, reserved - usage, max(0, smi - reserved)
 
 
+@check_cuda_device(0.0)
 def gpu_memory_usage_smi(device=0):
     if isinstance(device, torch.device):
         device = device.index
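A minimal usage sketch of the new check_cuda_device fallback behavior (illustrative only, not part of the patch; assumes axolotl is installed so axolotl.utils.bench is importable):

import torch

from axolotl.utils.bench import gpu_memory_usage, gpu_memory_usage_all

# "auto" (or a machine without CUDA) short-circuits to the decorator's
# default value instead of touching torch.cuda.
assert gpu_memory_usage("auto") == 0.0
assert gpu_memory_usage_all("auto") == (0.0, 0.0, 0.0)

# With CUDA present, device 0 is queried as before.
if torch.cuda.is_available():
    print(f"GPU 0 allocated: {gpu_memory_usage(0):.2f} GiB")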