Skip to content

Commit

Permalink
Skip coords with scalar value (#868)
Browse files Browse the repository at this point in the history
  • Loading branch information
GiannisApost authored and twiecki committed Sep 10, 2024
1 parent e6f9308 commit f68240f
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 0 deletions.
7 changes: 7 additions & 0 deletions pymc_marketing/mmm/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
import numpy.typing as npt
import xarray as xr

from pymc_marketing.mmm.utils import drop_scalar_coords

Values = Sequence[Any] | npt.NDArray[Any]
Coords = dict[str, Values]

Expand Down Expand Up @@ -100,6 +102,8 @@ def plot_hdi(
Figure and the axes
"""
curve = drop_scalar_coords(curve)

hdi_kwargs = hdi_kwargs or {}
conf = az.hdi(curve, **hdi_kwargs)[curve.name]

Expand Down Expand Up @@ -190,6 +194,8 @@ def plot_samples(
Figure and the axes
"""
curve = drop_scalar_coords(curve)

plot_coords = get_plot_coords(
curve.coords,
non_grid_names=non_grid_names.union({"chain", "draw"}),
Expand Down Expand Up @@ -262,6 +268,7 @@ def plot_curve(
Figure and the axes
"""
curve = drop_scalar_coords(curve)

hdi_kwargs = hdi_kwargs or {}
sample_kwargs = sample_kwargs or {}
Expand Down
27 changes: 27 additions & 0 deletions pymc_marketing/mmm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,3 +361,30 @@ def create_new_spend_data(
spend,
]
)


def drop_scalar_coords(curve: xr.DataArray) -> xr.DataArray:
"""
Remove scalar coordinates from an xarray DataArray.
This function identifies and removes scalar coordinates from the given
DataArray. Scalar coordinates are those with a single value that are
not part of the DataArray's indexes. The function returns a new DataArray
with the scalar coordinates removed.
Parameters
----------
curve : xr.DataArray
The input DataArray from which scalar coordinates will be removed.
Returns
-------
xr.DataArray
A new DataArray with the identified scalar coordinates removed.
"""
scalar_coords_to_drop = []
for coord, values in curve.coords.items():
if values.size == 1 and coord not in curve.indexes:
scalar_coords_to_drop.append(coord)

return curve.reset_coords(scalar_coords_to_drop, drop=True)
32 changes: 32 additions & 0 deletions tests/mmm/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
apply_sklearn_transformer_across_dim,
compute_sigmoid_second_derivative,
create_new_spend_data,
drop_scalar_coords,
estimate_menten_parameters,
estimate_sigmoid_parameters,
find_sigmoid_inflection_point,
Expand Down Expand Up @@ -281,3 +282,34 @@ def test_create_new_spend_data_value_errors() -> None:
one_time=True,
spend_leading_up=np.array([3, 4, 5]),
)


@pytest.fixture
def mock_curve_with_scalars() -> xr.DataArray:
coords = {
"x": [1, 2, 3],
"y": [10, 20, 30],
"scalar1": 42, # Scalar coordinate
"scalar2": 3.14, # Another scalar coordinate
}
data = np.random.rand(3, 3)
return xr.DataArray(data, coords=coords, dims=["x", "y"])


def test_drop_scalar_coords(mock_curve_with_scalars) -> None:
original_curve = mock_curve_with_scalars.copy(deep=True) # Make a deep copy
curve = drop_scalar_coords(mock_curve_with_scalars)

# Ensure scalar coordinates are removed
assert "scalar1" not in curve.coords
assert "scalar2" not in curve.coords

# Ensure other coordinates are still present
assert "x" in curve.coords
assert "y" in curve.coords

# Ensure data shape is unchanged
assert curve.shape == (3, 3)

# Ensure the original DataArray was not modified
xr.testing.assert_identical(mock_curve_with_scalars, original_curve)

0 comments on commit f68240f

Please sign in to comment.