diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 273d1027283..da45e0ac1d8 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -108,6 +108,32 @@ T_XarrayOther = TypeVar("T_XarrayOther", bound=Union["DataArray", Dataset]) +def _check_coords_dims(shape, coords, dims): + sizes = dict(zip(dims, shape)) + for k, v in coords.items(): + if any(d not in dims for d in v.dims): + raise ValueError( + f"coordinate {k} has dimensions {v.dims}, but these " + "are not a subset of the DataArray " + f"dimensions {dims}" + ) + + for d, s in zip(v.dims, v.shape): + if s != sizes[d]: + raise ValueError( + f"conflicting sizes for dimension {d!r}: " + f"length {sizes[d]} on the data but length {s} on " + f"coordinate {k!r}" + ) + + if k in sizes and v.shape != (sizes[k],): + raise ValueError( + f"coordinate {k!r} is a DataArray dimension, but " + f"it has shape {v.shape!r} rather than expected shape {sizes[k]!r} " + "matching the dimension size" + ) + + def _infer_coords_and_dims( shape, coords, dims ) -> tuple[dict[Hashable, Variable], tuple[Hashable, ...]]: @@ -159,29 +185,7 @@ def _infer_coords_and_dims( var.dims = (dim,) new_coords[dim] = var.to_index_variable() - sizes = dict(zip(dims, shape)) - for k, v in new_coords.items(): - if any(d not in dims for d in v.dims): - raise ValueError( - f"coordinate {k} has dimensions {v.dims}, but these " - "are not a subset of the DataArray " - f"dimensions {dims}" - ) - - for d, s in zip(v.dims, v.shape): - if s != sizes[d]: - raise ValueError( - f"conflicting sizes for dimension {d!r}: " - f"length {sizes[d]} on the data but length {s} on " - f"coordinate {k!r}" - ) - - if k in sizes and v.shape != (sizes[k],): - raise ValueError( - f"coordinate {k!r} is a DataArray dimension, but " - f"it has shape {v.shape!r} rather than expected shape {sizes[k]!r} " - "matching the dimension size" - ) + _check_coords_dims(shape, new_coords, dims) return new_coords, dims @@ -301,6 +305,11 @@ class DataArray( attrs : 
dict_like or None, optional Attributes to assign to the new instance. By default, an empty attribute dictionary is initialized. + indexes : :py:class:`~xarray.Indexes` or dict-like, optional + A collection of :py:class:`~xarray.indexes.Index` objects and + their coordinate variables. If an empty collection is given, + it will skip the creation of default (pandas) indexes for + dimension coordinates. Examples -------- @@ -389,21 +398,18 @@ def __init__( dims: Hashable | Sequence[Hashable] | None = None, name: Hashable | None = None, attrs: Mapping | None = None, + indexes: Mapping[Any, Index] | None = None, # internal parameters - indexes: dict[Hashable, Index] | None = None, fastpath: bool = False, ) -> None: if fastpath: variable = data assert dims is None assert attrs is None - assert indexes is not None + assert isinstance(indexes, dict) + da_indexes = indexes + da_coords = coords else: - # TODO: (benbovy - explicit indexes) remove - # once it becomes part of the public interface - if indexes is not None: - raise ValueError("Providing explicit indexes is not supported yet") - # try to fill in arguments from data if they weren't supplied if coords is None:
indexes, coords = _create_indexes_from_coords(coords) + + if create_default_indexes: + da_indexes, da_coords = _create_indexes_from_coords(da_coords) + else: + da_indexes = {} + + both_indexes_and_coords = set(indexes) & set(da_coords) + if both_indexes_and_coords: + raise ValueError( + f"{both_indexes_and_coords} are found in both indexes and coords" + ) + + _check_coords_dims(data.shape, indexes.variables, dims) + + da_coords.update( + {k: v.copy(deep=False) for k, v in indexes.variables.items()} + ) + da_indexes.update(indexes) # These fully describe a DataArray self._variable = variable - assert isinstance(coords, dict) - self._coords = coords + assert isinstance(da_coords, dict) + self._coords = da_coords self._name = name - - # TODO(shoyer): document this argument, once it becomes part of the - # public interface. - self._indexes = indexes + self._indexes = da_indexes # type: ignore[assignment] self._close = None @@ -3667,6 +3702,28 @@ def reduce( var = self.variable.reduce(func, dim, axis, keep_attrs, keepdims, **kwargs) return self._replace_maybe_drop_dims(var) + def assign_indexes(self, indexes: Indexes[Index]): + """Assign new indexes to this dataarray. + + Returns a new dataarray with all the original data in addition to the new + indexes (and their corresponding coordinates). + + Parameters + ---------- + indexes : :py:class:`~xarray.Indexes`. + A collection of :py:class:`~xarray.indexes.Index` objects + to assign (including their coordinate variables). + + Returns + ------- + assigned : DataArray + A new dataarray with the new indexes and coordinates in addition to + the existing data. + """ + # TODO: check indexes.dims must be a subset of self.dims + ds = self._to_temp_dataset().assign_indexes(indexes) + return self._from_temp_dataset(ds) + def to_pandas(self) -> DataArray | pd.Series | pd.DataFrame: """Convert this array into a pandas object with the same shape. 
diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 2e2fd6efa72..a9ebe6c1470 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -452,8 +452,10 @@ class Dataset( Dataset implements the mapping interface with keys given by variable names and values given by DataArray objects for each variable name. - One dimensional variables with name equal to their dimension are - index coordinates used for label based indexing. + By default, pandas indexes are created for one dimensional variables with + name equal to their dimension so those variables can be used as coordinates + for label based indexing. Xarray-compatible indexes may also be provided + via the `indexes` argument. To load data from a file or file-like object, use the `open_dataset` function. @@ -504,6 +506,11 @@ class Dataset( attrs : dict-like, optional Global attributes to save on this dataset. + indexes : :py:class:`~xarray.Indexes` or dict-like, optional + A collection of :py:class:`~xarray.indexes.Index` objects and + their coordinate variables. If an empty collection is given, + it will skip the creation of default (pandas) indexes for + dimension coordinates. Examples -------- @@ -563,6 +570,7 @@ class Dataset( precipitation float64 8.326 Attributes: description: Weather related data.
+ """ _attrs: dict[Hashable, Any] | None @@ -593,14 +601,26 @@ def __init__( data_vars: Mapping[Any, Any] | None = None, coords: Mapping[Any, Any] | None = None, attrs: Mapping[Any, Any] | None = None, + indexes: Mapping[Any, Index] | None = None, ) -> None: - # TODO(shoyer): expose indexes as a public argument in __init__ - if data_vars is None: data_vars = {} if coords is None: coords = {} + if indexes is None: + create_default_indexes = True + indexes = Indexes() + elif len(indexes) == 0: + create_default_indexes = False + indexes = Indexes() + else: + create_default_indexes = True + if not isinstance(indexes, Indexes): + raise TypeError("non-empty indexes must be an instance of `Indexes`") + elif indexes._index_type != Index: + raise TypeError("indexes must only contain Xarray `Index` objects") + both_data_and_coords = set(data_vars) & set(coords) if both_data_and_coords: raise ValueError( @@ -610,17 +630,34 @@ def __init__( if isinstance(coords, Dataset): coords = coords.variables - variables, coord_names, dims, indexes, _ = merge_data_and_coords( - data_vars, coords, compat="broadcast_equals" + variables, coord_names, dims, ds_indexes, _ = merge_data_and_coords( + data_vars, + coords, + compat="broadcast_equals", + create_default_indexes=create_default_indexes, ) + both_indexes_and_coords = set(indexes) & coord_names + if both_indexes_and_coords: + raise ValueError( + f"{both_indexes_and_coords} are found in both indexes and coords" + ) + + variables.update({k: v.copy(deep=False) for k, v in indexes.variables.items()}) + coord_names.update(indexes.variables) + ds_indexes.update(indexes) + + # re-calculate dimensions if indexes are given explicitly + if indexes: + dims = calculate_dimensions(variables) + self._attrs = dict(attrs) if attrs is not None else None self._close = None self._encoding = None self._variables = variables self._coord_names = coord_names self._dims = dims - self._indexes = indexes + self._indexes = ds_indexes @classmethod def 
load_store(cls: type[T_Dataset], store, decoder=None) -> T_Dataset: @@ -6080,6 +6117,30 @@ def assign( data.update(results) return data + def assign_indexes(self, indexes: Indexes[Index]): + """Assign new indexes to this dataset. + + Returns a new dataset with all the original data in addition to the new + indexes (and their corresponding coordinates). + + Parameters + ---------- + indexes : :py:class:`~xarray.Indexes`. + A collection of :py:class:`~xarray.indexes.Index` objects + to assign (including their coordinate variables). + + Returns + ------- + assigned : Dataset + A new dataset with the new indexes and coordinates in addition to + the existing data. + """ + ds_indexes = Dataset(indexes=indexes) + dropped = self.drop_vars(indexes, errors="ignore") + return dropped.merge( + ds_indexes, compat="minimal", join="override", combine_attrs="no_conflicts" + ) + def to_array( self, dim: Hashable = "variable", name: Hashable | None = None ) -> DataArray: diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index f3f03c9495b..29dd675480e 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -1089,19 +1089,22 @@ def create_default_index_implicit( class Indexes(collections.abc.Mapping, Generic[T_PandasOrXarrayIndex]): - """Immutable proxy for Dataset or DataArrary indexes. + """Immutable proxy for Dataset or DataArray indexes. - Keys are coordinate names and values may correspond to either pandas or - xarray indexes. + It is a mapping where keys are coordinate names and values are either pandas + or xarray indexes. - Also provides some utility methods. + It also contains the indexed coordinate variables and provides some utility + methods. 
""" + _index_type: type[Index] | type[pd.Index] _indexes: dict[Any, T_PandasOrXarrayIndex] _variables: dict[Any, Variable] __slots__ = ( + "_index_type", "_indexes", "_variables", "_dims", @@ -1112,8 +1115,9 @@ class Indexes(collections.abc.Mapping, Generic[T_PandasOrXarrayIndex]): def __init__( self, - indexes: dict[Any, T_PandasOrXarrayIndex], - variables: dict[Any, Variable], + indexes: Mapping[Any, T_PandasOrXarrayIndex] | None = None, + variables: Mapping[Any, Variable] | None = None, + index_type: type[Index] | type[pd.Index] = Index, ): """Constructor not for public consumption. @@ -1122,11 +1126,33 @@ def __init__( indexes : dict Indexes held by this object. variables : dict - Indexed coordinate variables in this object. + Indexed coordinate variables in this object. Entries must + match those of `indexes`. + index_type : type + The type of all indexes, i.e., either :py:class:`xarray.indexes.Index` + or :py:class:`pandas.Index`. """ - self._indexes = indexes - self._variables = variables + if indexes is None: + indexes = {} + if variables is None: + variables = {} + + unmatched_keys = set(indexes) ^ set(variables) + if unmatched_keys: + raise ValueError( + f"unmatched keys found in indexes and variables: {unmatched_keys}" + ) + + if any(not isinstance(idx, index_type) for idx in indexes.values()): + index_type_str = f"{index_type.__module__}.{index_type.__name__}" + raise TypeError( + f"values of indexes must all be instances of {index_type_str}" + ) + + self._index_type = index_type + self._indexes = dict(**indexes) + self._variables = dict(**variables) self._dims: Mapping[Hashable, int] | None = None self.__coord_name_id: dict[Any, int] | None = None @@ -1274,7 +1300,7 @@ def to_pandas_indexes(self) -> Indexes[pd.Index]: elif isinstance(idx, Index): indexes[k] = idx.to_pandas_index() - return Indexes(indexes, self._variables) + return Indexes(indexes, self._variables, index_type=pd.Index) def copy_indexes( self, deep: bool = True, memo: dict[int, Any] | 
None = None @@ -1514,3 +1540,33 @@ def assert_no_index_corrupted( f"the following index built from coordinates {index_names_str}:\n" f"{index}" ) + + +def wrap_pandas_multiindex(midx: pd.MultiIndex, dim: str) -> Indexes: + """Wrap a pandas multi-index as Xarray-compatible indexes + and coordinates. + + This function returns an object that can be directly assigned to a + :py:class:`~xarray.Dataset` or :py:class:`~xarray.DataArray` (via the + ``indexes`` argument of their constructor). + + Parameters + ---------- + midx : :py:class:`pandas.MultiIndex` + The pandas multi-index object to wrap. + dim : str + Dimension name. + + Returns + ------- + indexes : :py:class:`~xarray.Indexes` + An object that contains both the wrapped Xarray index and + its coordinate variables (dimension + levels). + + """ + xr_idx = PandasMultiIndex(midx, dim) + + variables = xr_idx.create_variables() + indexes = {k: xr_idx for k in variables} + + return Indexes(indexes=indexes, variables=variables) diff --git a/xarray/core/merge.py b/xarray/core/merge.py index 77cfb9bed75..b04fcf57a52 100644 --- a/xarray/core/merge.py +++ b/xarray/core/merge.py @@ -23,6 +23,7 @@ from xarray.core.indexes import ( Index, Indexes, + PandasIndex, create_default_index_implicit, filter_indexes_from_coords, indexes_equal, @@ -319,6 +320,7 @@ def merge_collected( def collect_variables_and_indexes( list_of_mappings: list[DatasetLike], indexes: Mapping[Any, Any] | None = None, + create_default_indexes: bool = True, ) -> dict[Hashable, list[MergeElement]]: """Collect variables and indexes from list of mappings of xarray objects.
@@ -365,7 +367,7 @@ def append_all(variables, indexes): variable = as_variable(variable, name=name) if name in indexes: append(name, variable, indexes[name]) - elif variable.dims == (name,): + elif variable.dims == (name,) and create_default_indexes: idx, idx_vars = create_default_index_implicit(variable) append_all(idx_vars, {k: idx for k in idx_vars}) else: @@ -567,9 +569,18 @@ def merge_coords( return variables, out_indexes -def merge_data_and_coords(data_vars, coords, compat="broadcast_equals", join="outer"): +def merge_data_and_coords( + data_vars, + coords, + compat="broadcast_equals", + join="outer", + create_default_indexes=True, +): """Used in Dataset.__init__.""" - indexes, coords = _create_indexes_from_coords(coords, data_vars) + if create_default_indexes: + indexes, coords = _create_indexes_from_coords(coords, data_vars) + else: + indexes = {} objects = [data_vars, coords] explicit_coords = coords.keys() return merge_core( @@ -577,11 +588,14 @@ def merge_data_and_coords(data_vars, coords, compat="broadcast_equals", join="ou compat, join, explicit_coords=explicit_coords, - indexes=Indexes(indexes, coords), + indexes=Indexes(indexes, {k: coords[k] for k in indexes}), + create_default_indexes=create_default_indexes, ) -def _create_indexes_from_coords(coords, data_vars=None): +def _create_indexes_from_coords( + coords: Mapping[Any, Variable], data_vars: Mapping[Any, Variable] | None = None +) -> tuple[dict[Any, PandasIndex], dict[Any, Variable]]: """Maybe create default indexes from a mapping of coordinates. Return those indexes and updated coordinates. @@ -702,6 +716,7 @@ def merge_core( explicit_coords: Sequence | None = None, indexes: Mapping[Any, Any] | None = None, fill_value: object = dtypes.NA, + create_default_indexes: bool = True, ) -> _MergeResult: """Core logic for merging labeled objects. @@ -727,6 +742,8 @@ def merge_core( may be cast to pandas.Index objects. 
fill_value : scalar, optional Value to use for newly missing values + create_default_indexes : bool, optional + If True, create default (pandas) indexes for dimension coordinates. Returns ------- @@ -752,7 +769,9 @@ def merge_core( aligned = deep_align( coerced, join=join, copy=False, indexes=indexes, fill_value=fill_value ) - collected = collect_variables_and_indexes(aligned, indexes=indexes) + collected = collect_variables_and_indexes( + aligned, indexes=indexes, create_default_indexes=create_default_indexes + ) prioritized = _get_priority_vars_and_indexes(aligned, priority_arg, compat=compat) variables, out_indexes = merge_collected( collected, prioritized, compat=compat, combine_attrs=combine_attrs diff --git a/xarray/indexes/__init__.py b/xarray/indexes/__init__.py index 143d7a58fda..2e114e9854e 100644 --- a/xarray/indexes/__init__.py +++ b/xarray/indexes/__init__.py @@ -2,6 +2,11 @@ DataArray objects. """ -from xarray.core.indexes import Index, PandasIndex, PandasMultiIndex +from xarray.core.indexes import ( + Index, + PandasIndex, + PandasMultiIndex, + wrap_pandas_multiindex, +) -__all__ = ["Index", "PandasIndex", "PandasMultiIndex"] +__all__ = ["Index", "PandasIndex", "PandasMultiIndex", "wrap_pandas_multiindex"] diff --git a/xarray/tests/test_indexes.py b/xarray/tests/test_indexes.py index 3ecfa73cc89..dd4edd59ebe 100644 --- a/xarray/tests/test_indexes.py +++ b/xarray/tests/test_indexes.py @@ -583,7 +583,12 @@ def indexes( _, variables = indexes_and_vars - return Indexes(indexes, variables) + if isinstance(x_idx, Index): + index_type = Index + else: + index_type = pd.Index + + return Indexes(indexes, variables, index_type=index_type) def test_interface(self, unique_indexes, indexes) -> None: x_idx = unique_indexes[0]