Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Skip coords with scalar value #868

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions pymc_marketing/mmm/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
import numpy.typing as npt
import xarray as xr

from pymc_marketing.mmm.utils import drop_scalar_coords

Values = Sequence[Any] | npt.NDArray[Any]
Coords = dict[str, Values]

Expand Down Expand Up @@ -100,6 +102,8 @@ def plot_hdi(
Figure and the axes

"""
curve = drop_scalar_coords(curve)

hdi_kwargs = hdi_kwargs or {}
conf = az.hdi(curve, **hdi_kwargs)[curve.name]

Expand Down Expand Up @@ -190,6 +194,8 @@ def plot_samples(
Figure and the axes

"""
curve = drop_scalar_coords(curve)

plot_coords = get_plot_coords(
curve.coords,
non_grid_names=non_grid_names.union({"chain", "draw"}),
Expand Down Expand Up @@ -262,6 +268,7 @@ def plot_curve(
Figure and the axes

"""
curve = drop_scalar_coords(curve)

hdi_kwargs = hdi_kwargs or {}
sample_kwargs = sample_kwargs or {}
Expand Down
27 changes: 27 additions & 0 deletions pymc_marketing/mmm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,3 +361,30 @@ def create_new_spend_data(
spend,
]
)


def drop_scalar_coords(curve: xr.DataArray) -> xr.DataArray:
"""
Remove scalar coordinates from an xarray DataArray.

This function identifies and removes scalar coordinates from the given
DataArray. Scalar coordinates are those with a single value that are
not part of the DataArray's indexes. The function returns a new DataArray
with the scalar coordinates removed.

Parameters
----------
curve : xr.DataArray
The input DataArray from which scalar coordinates will be removed.

Returns
-------
xr.DataArray
A new DataArray with the identified scalar coordinates removed.
"""
scalar_coords_to_drop = []
for coord, values in curve.coords.items():
if values.size == 1 and coord not in curve.indexes:
scalar_coords_to_drop.append(coord)

return curve.reset_coords(scalar_coords_to_drop, drop=True)
32 changes: 32 additions & 0 deletions tests/mmm/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
apply_sklearn_transformer_across_dim,
compute_sigmoid_second_derivative,
create_new_spend_data,
drop_scalar_coords,
estimate_menten_parameters,
estimate_sigmoid_parameters,
find_sigmoid_inflection_point,
Expand Down Expand Up @@ -281,3 +282,34 @@ def test_create_new_spend_data_value_errors() -> None:
one_time=True,
spend_leading_up=np.array([3, 4, 5]),
)


@pytest.fixture
def mock_curve_with_scalars() -> xr.DataArray:
coords = {
"x": [1, 2, 3],
"y": [10, 20, 30],
"scalar1": 42, # Scalar coordinate
"scalar2": 3.14, # Another scalar coordinate
}
data = np.random.rand(3, 3)
return xr.DataArray(data, coords=coords, dims=["x", "y"])


def test_drop_scalar_coords(mock_curve_with_scalars) -> None:
original_curve = mock_curve_with_scalars.copy(deep=True) # Make a deep copy
curve = drop_scalar_coords(mock_curve_with_scalars)

# Ensure scalar coordinates are removed
assert "scalar1" not in curve.coords
assert "scalar2" not in curve.coords

# Ensure other coordinates are still present
assert "x" in curve.coords
assert "y" in curve.coords

# Ensure data shape is unchanged
assert curve.shape == (3, 3)

# Ensure the original DataArray was not modified
xr.testing.assert_identical(mock_curve_with_scalars, original_curve)
Loading