From 058fe760b86e370e2a6fa5b807fde091fd7d23c4 Mon Sep 17 00:00:00 2001 From: Aobo Yang Date: Mon, 5 Dec 2022 23:57:11 -0800 Subject: [PATCH 1/2] improve error msg of invalid input types --- captum/_utils/common.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/captum/_utils/common.py b/captum/_utils/common.py index bba0ea293b..82e0c2b8bc 100644 --- a/captum/_utils/common.py +++ b/captum/_utils/common.py @@ -177,7 +177,8 @@ def _format_tensor_into_tuples( if not isinstance(inputs, tuple): assert isinstance( inputs, torch.Tensor - ), "`inputs` must have type " "torch.Tensor but {} found: ".format(type(inputs)) + ), "`inputs` must be a torch.Tensor or a tuple[torch.Tensor] " \ + f"but found: {type(inputs)}" inputs = (inputs,) return inputs From d37902aa81b7c6d372f87d3adfb9c1bd5b785929 Mon Sep 17 00:00:00 2001 From: Aobo Yang Date: Tue, 6 Dec 2022 22:30:53 -0800 Subject: [PATCH 2/2] format --- captum/_utils/common.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/captum/_utils/common.py b/captum/_utils/common.py index 82e0c2b8bc..3a80760c91 100644 --- a/captum/_utils/common.py +++ b/captum/_utils/common.py @@ -175,10 +175,10 @@ def _format_tensor_into_tuples( if inputs is None: return None if not isinstance(inputs, tuple): - assert isinstance( - inputs, torch.Tensor - ), "`inputs` must be a torch.Tensor or a tuple[torch.Tensor] " \ + assert isinstance(inputs, torch.Tensor), ( + "`inputs` must be a torch.Tensor or a tuple[torch.Tensor] " f"but found: {type(inputs)}" + ) inputs = (inputs,) return inputs