Skip to content

Commit

Permalink
add explicit check for dtype to convert to
Browse files Browse the repository at this point in the history
  • Loading branch information
justusschock committed Apr 3, 2020
1 parent df2b46d commit 140004a
Showing 1 changed file with 4 additions and 0 deletions.
4 changes: 4 additions & 0 deletions pytorch_lightning/metrics/converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ def _convert_to_tensor(data: Any) -> Any:
# is not array of object
elif isinstance(data, np.ndarray) and np_str_obj_array_pattern.search(data.dtype.str) is None:
return torch.from_numpy(data)
elif isinstance(data, torch.Tensor):
return data

raise TypeError("The given type ('%s') cannot be converted to a tensor!" % type(data).__name__)

Expand All @@ -94,6 +96,8 @@ def _convert_to_numpy(data: Union[torch.Tensor, np.ndarray, numbers.Number]) ->
return data.cpu().detach().numpy()
elif isinstance(data, numbers.Number):
return np.array([data])
elif isinstance(data, np.ndarray):
return data

raise TypeError("The given type ('%s') cannot be converted to a numpy array!" % type(data).__name__)

Expand Down

0 comments on commit 140004a

Please sign in to comment.