From 140004a0ff634f075d869c292e8fa0cedfc96b28 Mon Sep 17 00:00:00 2001 From: Justus Schock Date: Fri, 3 Apr 2020 14:29:20 +0200 Subject: [PATCH] add explicit check for dtype to convert to --- pytorch_lightning/metrics/converters.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pytorch_lightning/metrics/converters.py b/pytorch_lightning/metrics/converters.py index 629da04c7db21..8162876fc3b00 100644 --- a/pytorch_lightning/metrics/converters.py +++ b/pytorch_lightning/metrics/converters.py @@ -77,6 +77,8 @@ def _convert_to_tensor(data: Any) -> Any: # is not array of object elif isinstance(data, np.ndarray) and np_str_obj_array_pattern.search(data.dtype.str) is None: return torch.from_numpy(data) + elif isinstance(data, torch.Tensor): + return data raise TypeError("The given type ('%s') cannot be converted to a tensor!" % type(data).__name__) @@ -94,6 +96,8 @@ def _convert_to_numpy(data: Union[torch.Tensor, np.ndarray, numbers.Number]) -> return data.cpu().detach().numpy() elif isinstance(data, numbers.Number): return np.array([data]) + elif isinstance(data, np.ndarray): + return data raise TypeError("The given type ('%s') cannot be converted to a numpy array!" % type(data).__name__)