Skip to content

Commit

Permalink
Merge pull request #8 from epistoteles/feature-vmin-vmax
Browse files Browse the repository at this point in the history
Feature vmin vmax
  • Loading branch information
epistoteles committed Aug 7, 2024
2 parents 0c0e92a + d054269 commit f59e21c
Show file tree
Hide file tree
Showing 7 changed files with 106 additions and 17 deletions.
Binary file added .github/confusion_matrix.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added .github/images.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
22 changes: 20 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,9 @@ np.array([1,2,3]).viz() ❌
tensorhue.viz(np.array([1,2,3])) ✅
```

Pillow images get visualized in RGB using `.viz()`:
## Images

Pillow images can be visualized in RGB using `.viz()`:

```python
from torchvision.datasets import CIFAR10
Expand All @@ -69,6 +71,10 @@ img = dataset[0][0]
img.viz() ✅
```

<div align="center">
<img src="https://github.com/raw/epistoteles/tensorhue/main/.github/images.png" alt="image visualization" width="1000">
</div>

By default, images get downscaled to the size of your terminal, but you can make them even smaller if you want:

```python
Expand Down Expand Up @@ -96,7 +102,7 @@ Alternatively, you can overwrite the default ColorScheme:
tensorhue.set_printoptions(colorscheme=cs)
```

## Advanced colors
## Advanced colormaps and normalization

By default, TensorHue normalizes numerical values between 0 and 1 and then applies the matplotlib colormap. If you want to use diverging colormaps such as `coolwarm` or `bwr` and the value 0 to be mapped to the middle of the colormap, you need to specify the normailzer, e.g. `matplotlib.colors.CenteredNorm`:

Expand All @@ -106,3 +112,15 @@ cs = ColorScheme(colormap=colormaps['bwr'],
normalize=CenteredNorm(vcenter=0))
t.viz(cs)
```

You can also specify the normalization range manually, for example when you want to visualize a confusion matrix where colors should be mapped to the range [0, 1], but the actual values of the tensor in the range [0.01, 0.73]:

```
conf_matrix.viz(vmin=0, vmax=1, scale=3)
```

<div align="center">
<img src="https://github.com/raw/epistoteles/tensorhue/main/.github/confusion_matrix.png" alt="confusion matrix" width="1000">
</div>

The `scale` parameter scales up the 'pixels' of the tensor so that small tensors are easier to view.
2 changes: 1 addition & 1 deletion coverage-badge.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
28 changes: 24 additions & 4 deletions tensorhue/colors.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from __future__ import annotations
import warnings

from rich.color_triplet import ColorTriplet
from matplotlib import colormaps
from matplotlib.colors import Colormap, Normalize
from matplotlib.colors import Colormap, Normalize, CenteredNorm
import numpy as np


Expand Down Expand Up @@ -81,14 +82,33 @@ def ninf_color(self, value):
self._ninf_color = value
self._colormap.set_under(value.normalized)

def __call__(self, data: np.ndarray) -> np.ndarray:
def __call__(self, data: np.ndarray, **kwargs) -> np.ndarray:
if data.dtype == "bool":
true_values = np.array(self.true_color, dtype=np.uint8)
false_values = np.array(self.false_color, dtype=np.uint8)
return np.where(data[..., np.newaxis], true_values, false_values)
data_noinf = np.where(np.isinf(data), np.nan, data)
self.normalize.vmin = np.nanmin(data_noinf)
self.normalize.vmax = np.nanmax(data_noinf)
if "vmin" not in kwargs:
vmin = np.nanmin(data_noinf)
else:
vmin = float(kwargs["vmin"])
if "vmax" not in kwargs:
vmax = np.nanmax(data_noinf)
else:
vmax = float(kwargs["vmax"])
if isinstance(self.normalize, CenteredNorm):
vcenter = self.normalize.vcenter
diff_vmin = vmin - vcenter
diff_vmax = vmax - vcenter
max_abs_diff = max(abs(diff_vmin), abs(diff_vmax))
vmin = vcenter - max_abs_diff
vmax = vcenter + max_abs_diff
if "vmin" in kwargs and "vmax" in kwargs:
warnings.warn(
f"You shouldn't specify both 'vmin' and 'vmax' when using CenteredNorm. 'vmin' and 'vmax' must be symmetric around 'vcenter' and are thus inferred from a single value. Using: {vmin=}, {vcenter=}, {vmax=}."
)
self.normalize.vmin = vmin
self.normalize.vmax = vmax
return self.colormap(self.normalize(data), bytes=True)

def __repr__(self):
Expand Down
20 changes: 11 additions & 9 deletions tensorhue/viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,20 @@
from tensorhue.connectors._numpy import NumpyArrayWrapper


def viz(tensor, *args, **kwargs):
def viz(tensor, **kwargs):
if isinstance(tensor, np.ndarray):
tensor = NumpyArrayWrapper(tensor)
tensor.viz(*args, **kwargs) # pylint: disable=no-member
tensor.viz(**kwargs) # pylint: disable=no-member
else:
try:
tensor.viz(*args, **kwargs)
tensor.viz(**kwargs)
except Exception as e:
raise NotImplementedError(
f"TensorHue does not support type {type(tensor)}. Raise an issue if you need to visualize them. Alternatively, check if you imported tensorhue *after* your other library."
f"TensorHue currently does not support type {type(tensor)}. Please raise an issue if you want to visualize them. Alternatively, check if you imported tensorhue *after* your other library."
) from e


def _viz(self, colorscheme: ColorScheme = None, legend: bool = True, scale: int = 1):
def _viz(self, colorscheme: ColorScheme = None, legend: bool = True, scale: int = 1, **kwargs):
"""
Prints a tensor using colored Unicode art representation.
Expand All @@ -30,6 +30,7 @@ def _viz(self, colorscheme: ColorScheme = None, legend: bool = True, scale: int
Defaults to None, which means the global default color scheme is used.
legend (bool, optional): Whether or not to include legend information (like the shape)
scale (int, optional): Scales the size of the entire tensor up, making the unicode 'pixels' larger.
**kwargs: Additional keyword arguments that are passed to the underlying viz function (vmin or vmax)
"""
if not isinstance(scale, int):
raise ValueError("scale must be an integer.")
Expand All @@ -45,12 +46,12 @@ def _viz(self, colorscheme: ColorScheme = None, legend: bool = True, scale: int
self = self[np.newaxis, :]
elif ndim > 2:
raise NotImplementedError(
"Visualization for tensors with more than 2 dimensions is under development. Please slice them for now."
"Visualization of tensors with more than 2 dimensions is under development. Please slice them for now."
)

self = np.repeat(np.repeat(self, scale, axis=1), scale, axis=0)

result_lines = _viz_2d(self, colorscheme)
result_lines = _viz_2d(self, colorscheme, **kwargs)

if legend:
result_lines.append(f"[italic]shape = {shape}[/]")
Expand All @@ -59,13 +60,14 @@ def _viz(self, colorscheme: ColorScheme = None, legend: bool = True, scale: int
c.print("\n".join(result_lines))


def _viz_2d(array_2d: np.ndarray, colorscheme: ColorScheme = None) -> list[str]:
def _viz_2d(array_2d: np.ndarray, colorscheme: ColorScheme = None, **kwargs) -> list[str]:
"""
Constructs a list of rich-compatible strings out of a 2D numpy array.
Args:
array_2d (np.ndarray): The 2-dimensional numpy array (or 3-dimensional if the values are already RGB).
colorscheme (ColorScheme): The color scheme to use. If None, the array must be 3-dimensional (already RGB values).
**kwargs: Additional keyword arguments that are passed to the underlying viz function (vmin or vmax)
"""
terminal_width = get_terminal_size().columns
shape = array_2d.shape
Expand All @@ -85,7 +87,7 @@ def _viz_2d(array_2d: np.ndarray, colorscheme: ColorScheme = None) -> list[str]:
slice_right = colors_right = False

if colorscheme is not None:
colors_left = colorscheme(array_2d[:, :slice_left])[..., :3]
colors_left = colorscheme(array_2d[:, :slice_left], **kwargs)[..., :3]
else:
assert (
array_2d.ndim == 3 and array_2d.shape[-1] == 3
Expand Down
51 changes: 50 additions & 1 deletion tests/test_colors.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from rich.color_triplet import ColorTriplet
import numpy as np
from matplotlib.colors import Colormap
from matplotlib.colors import Colormap, CenteredNorm
from matplotlib import colormaps
from tensorhue.colors import ColorScheme, COLORS

Expand Down Expand Up @@ -39,3 +39,52 @@ def test_ColorScheme():

bool_array = np.array([True, False])
assert np.array_equal(cs(bool_array), np.array([cs.true_color, cs.false_color]))


def test_vmin_vmax():
cs = ColorScheme(colormap=colormaps["magma"])

values1 = np.array([-0.5, 0.0, 0.5, 0.75])

result1 = cs(values1)
assert np.array_equal(
result1,
np.array([[0, 0, 3, 255], [140, 41, 128, 255], [253, 159, 108, 255], [251, 252, 191, 255]], dtype=np.uint8),
)

result2 = cs(values1, vmin=-0.5)
assert np.array_equal(result1, result2)

result3 = cs(values1, vmax=0.75)
assert np.array_equal(result1, result3)

result4 = cs(values1, vmin=-0.5, vmax=0.75)
assert np.array_equal(result1, result4)

result5 = cs(values1, vmin=-1)
assert np.array_equal(
result5,
np.array([[94, 23, 127, 255], [211, 66, 109, 255], [254, 187, 128, 255], [251, 252, 191, 255]], dtype=np.uint8),
)

result6 = cs(values1, vmax=0.4)
assert np.array_equal(
result6,
np.array([[0, 0, 3, 255], [205, 63, 112, 255], [255, 255, 255, 255], [255, 255, 255, 255]], dtype=np.uint8),
)

cs = ColorScheme(colormap=colormaps["bwr"], normalize=CenteredNorm())

result7 = cs(values1)
assert np.array_equal(
result7,
np.array([[84, 84, 255, 255], [255, 254, 254, 255], [255, 84, 84, 255], [255, 0, 0, 255]], dtype=np.uint8),
)

result8 = cs(values1, vmin=-1)
assert np.array_equal(
result8,
np.array(
[[128, 128, 255, 255], [255, 254, 254, 255], [255, 126, 126, 255], [255, 62, 62, 255]], dtype=np.uint8
),
)

0 comments on commit f59e21c

Please sign in to comment.