Skip to content

Commit

Permalink
Refactor colorbar into an ipywidgets subclass (#1640)
Browse files Browse the repository at this point in the history
* Refactor colorbar into an ipywidgets subclass.

* Add map_widgets to docs

* Fix vertical colorbar bug

---------

Co-authored-by: Qiusheng Wu <giswqs@gmail.com>
  • Loading branch information
naschmitz and giswqs authored Jul 31, 2023
1 parent b21beb2 commit 3223cd9
Show file tree
Hide file tree
Showing 6 changed files with 390 additions and 131 deletions.
3 changes: 3 additions & 0 deletions docs/map_widgets.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# map_widgets module

::: geemap.map_widgets
10 changes: 8 additions & 2 deletions geemap/ee_tile_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ def _ee_object_to_image(ee_object, vis_params):


def _validate_palette(palette):
if isinstance(palette, tuple):
palette = list(palette)
if isinstance(palette, box.Box):
if "default" not in palette:
raise ValueError("The provided palette Box object is invalid.")
Expand Down Expand Up @@ -92,7 +94,9 @@ def __init__(
shown (bool, optional): A flag indicating whether the layer should be on by default. Defaults to True.
opacity (float, optional): The layer's opacity represented as a number between 0 and 1. Defaults to 1.
"""
self.url_format = _get_tile_url_format(ee_object, _validate_vis_params(vis_params))
self.url_format = _get_tile_url_format(
ee_object, _validate_vis_params(vis_params)
)
super().__init__(
tiles=self.url_format,
attr="Google Earth Engine",
Expand Down Expand Up @@ -127,7 +131,9 @@ def __init__(
shown (bool, optional): A flag indicating whether the layer should be on by default. Defaults to True.
opacity (float, optional): The layer's opacity represented as a number between 0 and 1. Defaults to 1.
"""
self.url_format = _get_tile_url_format(ee_object, _validate_vis_params(vis_params))
self.url_format = _get_tile_url_format(
ee_object, _validate_vis_params(vis_params)
)
super().__init__(
url=self.url_format,
attribution="Google Earth Engine",
Expand Down
149 changes: 20 additions & 129 deletions geemap/geemap.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,9 @@
from .common import *
from .conversion import *
from .ee_tile_layers import *
from .timelapse import *
from . import map_widgets
from .plot import *
from .timelapse import *

from . import examples

Expand Down Expand Up @@ -1014,7 +1015,7 @@ def add_colorbar(
layer_name=None,
font_size=9,
axis_off=False,
max_width="270px",
max_width=None,
**kwargs,
):
"""Add a matplotlib colorbar to the map
Expand All @@ -1030,142 +1031,32 @@ def add_colorbar(
layer_name (str, optional): The layer name associated with the colorbar. Defaults to None.
font_size (int, optional): Font size for the colorbar. Defaults to 9.
axis_off (bool, optional): Whether to turn off the axis. Defaults to False.
max_width (str, optional): Maximum width of the colorbar in pixels. Defaults to "300px".
max_width (str, optional): Maximum width of the colorbar in pixels. Defaults to None.
Raises:
TypeError: If the vis_params is not a dictionary.
ValueError: If the orientation is not either horizontal or vertical.
ValueError: If the provided min value is not scalar type.
ValueError: If the provided max value is not scalar type.
ValueError: If the provided opacity value is not scalar type.
ValueError: If cmap or palette is not provided.
TypeError: If the provided min value is not scalar type.
TypeError: If the provided max value is not scalar type.
TypeError: If the provided opacity value is not scalar type.
TypeError: If cmap or palette is not provided.
"""
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np

if isinstance(vis_params, list):
vis_params = {"palette": vis_params}
elif isinstance(vis_params, tuple):
vis_params = {"palette": list(vis_params)}
elif vis_params is None:
vis_params = {}

if "colors" in kwargs and isinstance(kwargs["colors"], list):
vis_params["palette"] = kwargs["colors"]

if "colors" in kwargs and isinstance(kwargs["colors"], tuple):
vis_params["palette"] = list(kwargs["colors"])

if "vmin" in kwargs:
vis_params["min"] = kwargs["vmin"]
del kwargs["vmin"]

if "vmax" in kwargs:
vis_params["max"] = kwargs["vmax"]
del kwargs["vmax"]

if "caption" in kwargs:
label = kwargs["caption"]
del kwargs["caption"]

if not isinstance(vis_params, dict):
raise TypeError("The vis_params must be a dictionary.")

if orientation not in ["horizontal", "vertical"]:
raise ValueError("The orientation must be either horizontal or vertical.")

if orientation == "horizontal":
width, height = 3.0, 0.3
else:
width, height = 0.3, 3.0

if "width" in kwargs:
width = kwargs["width"]
kwargs.pop("width")

if "height" in kwargs:
height = kwargs["height"]
kwargs.pop("height")

vis_keys = list(vis_params.keys())

if "min" in vis_params:
vmin = vis_params["min"]
if type(vmin) not in (int, float):
raise ValueError("The provided min value must be scalar type.")
else:
vmin = 0

if "max" in vis_params:
vmax = vis_params["max"]
if type(vmax) not in (int, float):
raise ValueError("The provided max value must be scalar type.")
else:
vmax = 1

if "opacity" in vis_params:
alpha = vis_params["opacity"]
if type(alpha) not in (int, float):
raise ValueError("The provided opacity value must be type scalar.")
elif "alpha" in kwargs:
alpha = kwargs["alpha"]
else:
alpha = 1

if cmap is not None:
cmap = mpl.pyplot.get_cmap(cmap)
norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)

if "palette" in vis_keys:
hexcodes = to_hex_colors(check_cmap(vis_params["palette"]))
if discrete:
cmap = mpl.colors.ListedColormap(hexcodes)
vals = np.linspace(vmin, vmax, cmap.N + 1)
norm = mpl.colors.BoundaryNorm(vals, cmap.N)

else:
cmap = mpl.colors.LinearSegmentedColormap.from_list(
"custom", hexcodes, N=256
)
norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)

elif cmap is not None:
cmap = mpl.pyplot.get_cmap(cmap)
norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)

else:
raise ValueError(
'cmap keyword or "palette" key in vis_params must be provided.'
)

fig, ax = plt.subplots(figsize=(width, height))
cb = mpl.colorbar.ColorbarBase(
ax, norm=norm, alpha=alpha, cmap=cmap, orientation=orientation, **kwargs
colorbar = map_widgets.Colorbar(
vis_params,
cmap,
discrete,
label,
orientation,
transparent_bg,
font_size,
axis_off,
max_width,
**kwargs,
)

if label is not None:
cb.set_label(label, fontsize=font_size)
elif "bands" in vis_keys:
cb.set_label(vis_params["bands"], fontsize=font_size)

if axis_off:
ax.set_axis_off()
ax.tick_params(labelsize=font_size)

# set the background color to transparent
if transparent_bg:
fig.patch.set_alpha(0.0)

output = widgets.Output(layout=widgets.Layout(width=max_width))
colormap_ctrl = ipyleaflet.WidgetControl(
widget=output,
position=position,
transparent_bg=transparent_bg,
widget=colorbar, position=position, transparent_bg=transparent_bg
)
with output:
output.outputs = ()
plt.show()

self._colorbar = colormap_ctrl
if layer_name in self.ee_layer_names:
Expand Down
136 changes: 136 additions & 0 deletions geemap/map_widgets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
"""Various ipywidgets that can be added to a map."""

import ipywidgets

from . import common


class Colorbar(ipywidgets.Output):
"""A matplotlib colorbar widget that can be added to the map."""

def __init__(
self,
vis_params=None,
cmap="gray",
discrete=False,
label=None,
orientation="horizontal",
transparent_bg=False,
font_size=9,
axis_off=False,
max_width=None,
**kwargs,
):
"""Add a matplotlib colorbar to the map.
Args:
vis_params (dict): Visualization parameters as a dictionary. See https://developers.google.com/earth-engine/guides/image_visualization for options.
cmap (str, optional): Matplotlib colormap. Defaults to "gray". See https://matplotlib.org/3.3.4/tutorials/colors/colormaps.html#sphx-glr-tutorials-colors-colormaps-py for options.
discrete (bool, optional): Whether to create a discrete colorbar. Defaults to False.
label (str, optional): Label for the colorbar. Defaults to None.
orientation (str, optional): Orientation of the colorbar, such as "vertical" and "horizontal". Defaults to "horizontal".
transparent_bg (bool, optional): Whether to use transparent background. Defaults to False.
font_size (int, optional): Font size for the colorbar. Defaults to 9.
axis_off (bool, optional): Whether to turn off the axis. Defaults to False.
max_width (str, optional): Maximum width of the colorbar in pixels. Defaults to None.
Raises:
TypeError: If the vis_params is not a dictionary.
ValueError: If the orientation is not either horizontal or vertical.
ValueError: If the provided min value is not scalar type.
ValueError: If the provided max value is not scalar type.
ValueError: If the provided opacity value is not scalar type.
ValueError: If cmap or palette is not provided.
"""

import matplotlib # pylint: disable=import-outside-toplevel
import numpy # pylint: disable=import-outside-toplevel

if max_width is None:
if orientation == "horizontal":
max_width = "270px"
else:
max_width = "100px"

if isinstance(vis_params, (list, tuple)):
vis_params = {"palette": list(vis_params)}
elif not vis_params:
vis_params = {}

if not isinstance(vis_params, dict):
raise TypeError("The vis_params must be a dictionary.")

if isinstance(kwargs.get("colors"), (list, tuple)):
vis_params["palette"] = list(kwargs["colors"])

width, height = self._get_dimensions(orientation, kwargs)

vmin = vis_params.get("min", kwargs.pop("vmin", 0))
if type(vmin) not in (int, float):
raise TypeError("The provided min value must be scalar type.")

vmax = vis_params.get("max", kwargs.pop("mvax", 1))
if type(vmax) not in (int, float):
raise TypeError("The provided max value must be scalar type.")

alpha = vis_params.get("opacity", kwargs.pop("alpha", 1))
if type(alpha) not in (int, float):
raise TypeError("The provided opacity or alpha value must be type scalar.")

if "palette" in vis_params.keys():
hexcodes = common.to_hex_colors(common.check_cmap(vis_params["palette"]))
if discrete:
cmap = matplotlib.colors.ListedColormap(hexcodes)
linspace = numpy.linspace(vmin, vmax, cmap.N + 1)
norm = matplotlib.colors.BoundaryNorm(linspace, cmap.N)
else:
cmap = matplotlib.colors.LinearSegmentedColormap.from_list(
"custom", hexcodes, N=256
)
norm = matplotlib.colors.Normalize(vmin=vmin, vmax=vmax)
elif cmap:
cmap = matplotlib.pyplot.get_cmap(cmap)
norm = matplotlib.colors.Normalize(vmin=vmin, vmax=vmax)
else:
raise ValueError(
'cmap keyword or "palette" key in vis_params must be provided.'
)

fig, ax = matplotlib.pyplot.subplots(figsize=(width, height))
cb = matplotlib.colorbar.ColorbarBase(
ax,
norm=norm,
alpha=alpha,
cmap=cmap,
orientation=orientation,
**kwargs,
)

label = label or vis_params.get("bands") or kwargs.pop("caption", None)
if label:
cb.set_label(label, fontsize=font_size)

if axis_off:
ax.set_axis_off()
ax.tick_params(labelsize=font_size)

# Set the background color to transparent.
if transparent_bg:
fig.patch.set_alpha(0.0)

super().__init__(layout=ipywidgets.Layout(width=max_width))
with self:
self.outputs = ()
matplotlib.pyplot.show()

def _get_dimensions(self, orientation, kwargs):
default_dims = {"horizontal": (3.0, 0.3), "vertical": (0.3, 3.0)}
if orientation in default_dims:
default = default_dims[orientation]
return (
kwargs.get("width", default[0]),
kwargs.get("height", default[1]),
)
raise ValueError(
f"orientation must be one of [{', '.join(default_dims.keys())}]."
)
1 change: 1 addition & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ nav:
- geemap module: geemap.md
- kepler module: kepler.md
- legends module: legends.md
- map_widgets module: map_widgets.md
- ml module: ml.md
- osm module: osm.md
- plot module: plot.md
Expand Down
Loading

0 comments on commit 3223cd9

Please sign in to comment.