From 765d96982f5c2be5b2ab39b284348bc35ff9064c Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Sun, 28 Mar 2021 17:21:19 +0200 Subject: [PATCH] PyTorch Hub custom model to CUDA device fix Fix for #2630 raised by @Pro100rus32 --- hubconf.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/hubconf.py b/hubconf.py index 710882cf158f..0eaf70787e64 100644 --- a/hubconf.py +++ b/hubconf.py @@ -128,7 +128,10 @@ def custom(path_or_model='path/to/model.pt', autoshape=True): hub_model = Model(model.yaml).to(next(model.parameters()).device) # create hub_model.load_state_dict(model.float().state_dict()) # load state_dict hub_model.names = model.names # class names - return hub_model.autoshape() if autoshape else hub_model + if autoshape: + hub_model = hub_model.autoshape() # for file/URI/PIL/cv2/np inputs and NMS + device = select_device('0' if torch.cuda.is_available() else 'cpu') # default to GPU if available + return hub_model.to(device) if __name__ == '__main__':