diff --git a/export.py b/export.py index a825a73b2d3c..574fee050b94 100644 --- a/export.py +++ b/export.py @@ -486,7 +486,7 @@ def run( if half: assert device.type != 'cpu' or coreml or xml, '--half only compatible with GPU export, i.e. use --device 0' assert not dynamic, '--half not compatible with --dynamic, i.e. use either --half or --dynamic but not both' - model = attempt_load(weights, map_location=device, inplace=True, fuse=True) # load FP32 model + model = attempt_load(weights, device=device, inplace=True, fuse=True) # load FP32 model nc, names = model.nc, model.names # number of classes, class names # Checks diff --git a/models/common.py b/models/common.py index bad55b4024c0..66467a0ab1b7 100644 --- a/models/common.py +++ b/models/common.py @@ -331,7 +331,7 @@ def __init__(self, weights='yolov5s.pt', device=torch.device('cpu'), dnn=False, names = yaml.safe_load(f)['names'] if pt: # PyTorch - model = attempt_load(weights if isinstance(weights, list) else w, map_location=device) + model = attempt_load(weights if isinstance(weights, list) else w, device=device) stride = max(int(model.stride.max()), 32) # model stride names = model.module.names if hasattr(model, 'module') else model.names # get class names model.half() if fp16 else model.float() diff --git a/models/experimental.py b/models/experimental.py index 7bf249e80984..6ed528a335d2 100644 --- a/models/experimental.py +++ b/models/experimental.py @@ -71,14 +71,14 @@ def forward(self, x, augment=False, profile=False, visualize=False): return y, None # inference, train output -def attempt_load(weights, map_location=None, inplace=True, fuse=True): +def attempt_load(weights, device=None, inplace=True, fuse=True): from models.yolo import Detect, Model # Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a model = Ensemble() for w in weights if isinstance(weights, list) else [weights]: - ckpt = torch.load(attempt_download(w), map_location=map_location) # load - ckpt = (ckpt.get('ema') or ckpt['model']).float() # FP32 model + ckpt = torch.load(attempt_download(w)) + ckpt = (ckpt.get('ema') or ckpt['model']).to(device).float() # FP32 model model.append(ckpt.fuse().eval() if fuse else ckpt.eval()) # fused or un-fused model in eval mode # Compatibility updates diff --git a/models/tf.py b/models/tf.py index 202a957e3e63..b0d98cc2a3a9 100644 --- a/models/tf.py +++ b/models/tf.py @@ -536,7 +536,7 @@ def run( ): # PyTorch model im = torch.zeros((batch_size, 3, *imgsz)) # BCHW image - model = attempt_load(weights, map_location=torch.device('cpu'), inplace=True, fuse=False) + model = attempt_load(weights, device=torch.device('cpu'), inplace=True, fuse=False) _ = model(im) # inference model.info() diff --git a/utils/torch_utils.py b/utils/torch_utils.py index ca32b1999e81..d11df8337300 100644 --- a/utils/torch_utils.py +++ b/utils/torch_utils.py @@ -54,7 +54,8 @@ def select_device(device='', batch_size=0, newline=True): s = f'YOLOv5 🚀 {git_describe() or file_date()} Python-{platform.python_version()} torch-{torch.__version__} ' device = str(device).strip().lower().replace('cuda:', '').replace('none', '') # to string, 'cuda:0' to '0' cpu = device == 'cpu' - if cpu: + mps = device == 'mps' # Apple Metal Performance Shaders (MPS) + if cpu or mps: os.environ['CUDA_VISIBLE_DEVICES'] = '-1' # force torch.cuda.is_available() = False elif device: # non-cpu device requested os.environ['CUDA_VISIBLE_DEVICES'] = device # set environment variable - must be before assert is_available() @@ -71,13 +72,15 @@ def select_device(device='', batch_size=0, newline=True): for i, d in enumerate(devices): p = torch.cuda.get_device_properties(i) s += f"{'' if i == 0 else space}CUDA:{d} ({p.name}, {p.total_memory / (1 << 20):.0f}MiB)\n" # bytes to MB + elif mps: + s += 'MPS\n' else: s += 'CPU\n' if not newline: s = s.rstrip() LOGGER.info(s.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else s) # emoji-safe - return torch.device('cuda:0' if cuda else 'cpu') + return torch.device('cuda:0' if cuda else 'mps' if mps else 'cpu') def time_sync():