diff --git a/doc/api.rst b/doc/api.rst
index ceab7dcc976..9cb02441d37 100644
--- a/doc/api.rst
+++ b/doc/api.rst
@@ -126,6 +126,7 @@ Indexing
    Dataset.isel
    Dataset.sel
    Dataset.drop_sel
+   Dataset.drop_isel
    Dataset.head
    Dataset.tail
    Dataset.thin
@@ -307,6 +308,7 @@ Indexing
    DataArray.isel
    DataArray.sel
    DataArray.drop_sel
+   DataArray.drop_isel
    DataArray.head
    DataArray.tail
    DataArray.thin
diff --git a/doc/whats-new.rst b/doc/whats-new.rst
index 6f47446a2eb..59a9fc45e8d 100644
--- a/doc/whats-new.rst
+++ b/doc/whats-new.rst
@@ -39,8 +39,8 @@ Breaking changes
   always be set such that ``int64`` values can be used.  In the past, no units
   finer than "seconds" were chosen, which would sometimes mean that ``float64``
   values were required, which would lead to inaccurate I/O round-trips.
-- remove deprecated ``autoclose`` kwargs from :py:func:`open_dataset` (:pull: `4725`).
-  By `Aureliana Barghini `_
+- remove deprecated ``autoclose`` kwargs from :py:func:`open_dataset` (:pull:`4725`).
+  By `Aureliana Barghini `_.
 
 Deprecations
 ~~~~~~~~~~~~
@@ -87,6 +87,7 @@ Bug fixes
 - Expand user directory paths (e.g. ``~/``) in :py:func:`open_mfdataset` and
   :py:meth:`Dataset.to_zarr` (:issue:`4783`, :pull:`4795`).
   By `Julien Seguinot `_.
+- Add :py:meth:`Dataset.drop_isel` and :py:meth:`DataArray.drop_isel` (:issue:`4658`, :pull:`4819`). By `Daniel Mesejo `_.
 
 Documentation
 ~~~~~~~~~~~~~
@@ -115,6 +116,8 @@ Internal Changes
   By `Maximilian Roos `_.
 - Speed up attribute style access (e.g. ``ds.somevar`` instead of ``ds["somevar"]``) and
   tab completion in ipython (:issue:`4741`, :pull:`4742`). By `Richard Kleijn `_.
+- Added the ``set_close`` method to ``Dataset`` and ``DataArray`` for backends
+  to specify how to voluntarily release all resources (:pull:`4809`). By `Alessandro Amici `_.
 
 .. _whats-new.0.16.2:
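The ``set_close`` entry above replaces the old pattern of stashing a file handle on ``ds._file_obj``. A minimal sketch of the intended use, assuming this patch is applied (the opened file is purely illustrative, not part of xarray's API):

    import xarray as xr

    ds = xr.Dataset({"a": ("x", [1, 2, 3])})
    f = open("ancillary.bin", "rb")  # any resource exposing a close() method
    ds.set_close(f.close)  # register the closer with the dataset
    ds.close()  # now invokes f.close(), then forgets the closer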
diff --git a/xarray/backends/api.py b/xarray/backends/api.py
index 4958062a262..81314588784 100644
--- a/xarray/backends/api.py
+++ b/xarray/backends/api.py
@@ -522,7 +522,7 @@ def maybe_decode_store(store, chunks):
         else:
             ds2 = ds
 
-        ds2._file_obj = ds._file_obj
+        ds2.set_close(ds._close)
         return ds2
 
     filename_or_obj = _normalize_path(filename_or_obj)
@@ -701,7 +701,7 @@ def open_dataarray(
     else:
         (data_array,) = dataset.data_vars.values()
 
-    data_array._file_obj = dataset._file_obj
+    data_array.set_close(dataset._close)
 
     # Reset names if they were changed during saving
     # to ensure that we can 'roundtrip' perfectly
@@ -715,17 +715,6 @@ def open_dataarray(
     return data_array
 
 
-class _MultiFileCloser:
-    __slots__ = ("file_objs",)
-
-    def __init__(self, file_objs):
-        self.file_objs = file_objs
-
-    def close(self):
-        for f in self.file_objs:
-            f.close()
-
-
 def open_mfdataset(
     paths,
     chunks=None,
@@ -918,14 +907,14 @@ def open_mfdataset(
         getattr_ = getattr
 
     datasets = [open_(p, **open_kwargs) for p in paths]
-    file_objs = [getattr_(ds, "_file_obj") for ds in datasets]
+    closers = [getattr_(ds, "_close") for ds in datasets]
    if preprocess is not None:
        datasets = [preprocess(ds) for ds in datasets]
 
    if parallel:
        # calling compute here will return the datasets/file_objs lists,
        # the underlying datasets will still be stored as dask arrays
-        datasets, file_objs = dask.compute(datasets, file_objs)
+        datasets, closers = dask.compute(datasets, closers)
 
     # Combine all datasets, closing them in case of a ValueError
     try:
@@ -963,7 +952,11 @@ def open_mfdataset(
                 ds.close()
         raise
 
-    combined._file_obj = _MultiFileCloser(file_objs)
+    def multi_file_closer():
+        for closer in closers:
+            closer()
+
+    combined.set_close(multi_file_closer)
 
     # read global attributes from the attrs_file or from the first dataset
     if attrs_file is not None:
diff --git a/xarray/backends/apiv2.py b/xarray/backends/apiv2.py
index 0f98291983d..d31fc9ea773 100644
--- a/xarray/backends/apiv2.py
+++ b/xarray/backends/apiv2.py
@@ -90,7 +90,7 @@ def _dataset_from_backend_dataset(
         **extra_tokens,
     )
 
-    ds._file_obj = backend_ds._file_obj
+    ds.set_close(backend_ds._close)
 
     # Ensure source filename always stored in dataset object (GH issue #2550)
     if "source" not in ds.encoding:
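Every backend module below follows the same import-guard pattern: the optional dependency is imported once at module level, the module still imports cleanly when the dependency is missing, and the backend registers itself only when the import succeeded. Condensed, with ``somepackage`` and ``somebackend`` as placeholders:

    try:
        import somepackage  # the backend's optional dependency

        has_somepackage = True
    except ModuleNotFoundError:
        has_somepackage = False

    # ... module body; somepackage is only touched when the backend is used ...

    if has_somepackage:
        BACKEND_ENTRYPOINTS["somebackend"] = somebackend_backend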
diff --git a/xarray/backends/cfgrib_.py b/xarray/backends/cfgrib_.py
index d4933e370c7..4a0ac7d67f9 100644
--- a/xarray/backends/cfgrib_.py
+++ b/xarray/backends/cfgrib_.py
@@ -5,10 +5,23 @@
 from ..core import indexing
 from ..core.utils import Frozen, FrozenDict, close_on_error
 from ..core.variable import Variable
-from .common import AbstractDataStore, BackendArray, BackendEntrypoint
+from .common import (
+    BACKEND_ENTRYPOINTS,
+    AbstractDataStore,
+    BackendArray,
+    BackendEntrypoint,
+)
 from .locks import SerializableLock, ensure_lock
 from .store import open_backend_dataset_store
 
+try:
+    import cfgrib
+
+    has_cfgrib = True
+except ModuleNotFoundError:
+    has_cfgrib = False
+
+
 # FIXME: Add a dedicated lock, even if ecCodes is supposed to be thread-safe
 # in most circumstances. See:
 # https://confluence.ecmwf.int/display/ECC/Frequently+Asked+Questions
@@ -38,7 +51,6 @@ class CfGribDataStore(AbstractDataStore):
     """
 
     def __init__(self, filename, lock=None, **backend_kwargs):
-        import cfgrib
 
         if lock is None:
             lock = ECCODES_LOCK
@@ -129,3 +141,7 @@ def open_backend_dataset_cfgrib(
 cfgrib_backend = BackendEntrypoint(
     open_dataset=open_backend_dataset_cfgrib, guess_can_open=guess_can_open_cfgrib
 )
+
+
+if has_cfgrib:
+    BACKEND_ENTRYPOINTS["cfgrib"] = cfgrib_backend
diff --git a/xarray/backends/common.py b/xarray/backends/common.py
index 72a63957662..adb70658fab 100644
--- a/xarray/backends/common.py
+++ b/xarray/backends/common.py
@@ -1,6 +1,7 @@
 import logging
 import time
 import traceback
+from typing import Dict
 
 import numpy as np
 
@@ -349,3 +350,6 @@ def __init__(self, open_dataset, open_dataset_parameters=None, guess_can_open=No
         self.open_dataset = open_dataset
         self.open_dataset_parameters = open_dataset_parameters
         self.guess_can_open = guess_can_open
+
+
+BACKEND_ENTRYPOINTS: Dict[str, BackendEntrypoint] = {}
diff --git a/xarray/backends/h5netcdf_.py b/xarray/backends/h5netcdf_.py
index b2996369ee7..562600de4b6 100644
--- a/xarray/backends/h5netcdf_.py
+++ b/xarray/backends/h5netcdf_.py
@@ -8,7 +8,12 @@
 from ..core import indexing
 from ..core.utils import FrozenDict, is_remote_uri, read_magic_number
 from ..core.variable import Variable
-from .common import BackendEntrypoint, WritableCFDataStore, find_root_and_group
+from .common import (
+    BACKEND_ENTRYPOINTS,
+    BackendEntrypoint,
+    WritableCFDataStore,
+    find_root_and_group,
+)
 from .file_manager import CachingFileManager, DummyFileManager
 from .locks import HDF5_LOCK, combine_locks, ensure_lock, get_write_lock
 from .netCDF4_ import (
@@ -20,6 +25,13 @@
 )
 from .store import open_backend_dataset_store
 
+try:
+    import h5netcdf
+
+    has_h5netcdf = True
+except ModuleNotFoundError:
+    has_h5netcdf = False
+
 
 class H5NetCDFArrayWrapper(BaseNetCDF4Array):
     def get_array(self, needs_lock=True):
@@ -85,8 +97,6 @@ class H5NetCDFStore(WritableCFDataStore):
 
     def __init__(self, manager, group=None, mode=None, lock=HDF5_LOCK, autoclose=False):
-        import h5netcdf
-
         if isinstance(manager, (h5netcdf.File, h5netcdf.Group)):
             if group is None:
                 root, group = find_root_and_group(manager)
@@ -122,7 +132,6 @@ def open(
         invalid_netcdf=None,
         phony_dims=None,
     ):
-        import h5netcdf
 
         if isinstance(filename, bytes):
             raise ValueError(
@@ -375,3 +384,6 @@ def open_backend_dataset_h5netcdf(
 h5netcdf_backend = BackendEntrypoint(
     open_dataset=open_backend_dataset_h5netcdf, guess_can_open=guess_can_open_h5netcdf
 )
+
+if has_h5netcdf:
+    BACKEND_ENTRYPOINTS["h5netcdf"] = h5netcdf_backend
diff --git a/xarray/backends/netCDF4_.py b/xarray/backends/netCDF4_.py
index 0e35270ea9a..5bb4eec837b 100644
--- a/xarray/backends/netCDF4_.py
+++ b/xarray/backends/netCDF4_.py
@@ -12,6 +12,7 @@
 from ..core.utils import FrozenDict, close_on_error, is_remote_uri
 from ..core.variable import Variable
 from .common import (
+    BACKEND_ENTRYPOINTS,
     BackendArray,
     BackendEntrypoint,
     WritableCFDataStore,
@@ -23,6 +24,14 @@
 from .netcdf3 import encode_nc3_attr_value, encode_nc3_variable
 from .store import open_backend_dataset_store
 
+try:
+    import netCDF4
+
+    has_netcdf4 = True
+except ModuleNotFoundError:
+    has_netcdf4 = False
+
+
 # This lookup table maps from dtype.byteorder to a readable endian
 # string used by netCDF4.
 _endian_lookup = {"=": "native", ">": "big", "<": "little", "|": "native"}
@@ -298,7 +307,6 @@ class NetCDF4DataStore(WritableCFDataStore):
     def __init__(
         self, manager, group=None, mode=None, lock=NETCDF4_PYTHON_LOCK, autoclose=False
     ):
-        import netCDF4
 
         if isinstance(manager, netCDF4.Dataset):
             if group is None:
@@ -335,7 +343,6 @@ def open(
         lock_maker=None,
         autoclose=False,
     ):
-        import netCDF4
 
         if isinstance(filename, pathlib.Path):
             filename = os.fspath(filename)
@@ -563,3 +570,7 @@ def open_backend_dataset_netcdf4(
 netcdf4_backend = BackendEntrypoint(
     open_dataset=open_backend_dataset_netcdf4, guess_can_open=guess_can_open_netcdf4
 )
+
+
+if has_netcdf4:
+    BACKEND_ENTRYPOINTS["netcdf4"] = netcdf4_backend
diff --git a/xarray/backends/plugins.py b/xarray/backends/plugins.py
index d5799a78f91..6d3ec7e7da5 100644
--- a/xarray/backends/plugins.py
+++ b/xarray/backends/plugins.py
@@ -2,33 +2,11 @@
 import inspect
 import itertools
 import logging
-import typing as T
 import warnings
 
 import pkg_resources
 
-from .cfgrib_ import cfgrib_backend
-from .common import BackendEntrypoint
-from .h5netcdf_ import h5netcdf_backend
-from .netCDF4_ import netcdf4_backend
-from .pseudonetcdf_ import pseudonetcdf_backend
-from .pydap_ import pydap_backend
-from .pynio_ import pynio_backend
-from .scipy_ import scipy_backend
-from .store import store_backend
-from .zarr import zarr_backend
-
-BACKEND_ENTRYPOINTS: T.Dict[str, BackendEntrypoint] = {
-    "store": store_backend,
-    "netcdf4": netcdf4_backend,
-    "h5netcdf": h5netcdf_backend,
-    "scipy": scipy_backend,
-    "pseudonetcdf": pseudonetcdf_backend,
-    "zarr": zarr_backend,
-    "cfgrib": cfgrib_backend,
-    "pydap": pydap_backend,
-    "pynio": pynio_backend,
-}
+from .common import BACKEND_ENTRYPOINTS
 
 
 def remove_duplicates(backend_entrypoints):
diff --git a/xarray/backends/pseudonetcdf_.py b/xarray/backends/pseudonetcdf_.py
index d9128d1d503..c2bfd519bed 100644
--- a/xarray/backends/pseudonetcdf_.py
+++ b/xarray/backends/pseudonetcdf_.py
@@ -3,11 +3,24 @@
 from ..core import indexing
 from ..core.utils import Frozen, FrozenDict, close_on_error
 from ..core.variable import Variable
-from .common import AbstractDataStore, BackendArray, BackendEntrypoint
+from .common import (
+    BACKEND_ENTRYPOINTS,
+    AbstractDataStore,
+    BackendArray,
+    BackendEntrypoint,
+)
 from .file_manager import CachingFileManager
 from .locks import HDF5_LOCK, NETCDFC_LOCK, combine_locks, ensure_lock
 from .store import open_backend_dataset_store
 
+try:
+    from PseudoNetCDF import pncopen
+
+    has_pseudonetcdf = True
+except ModuleNotFoundError:
+    has_pseudonetcdf = False
+
+
 # psuedonetcdf can invoke netCDF libraries internally
 PNETCDF_LOCK = combine_locks([HDF5_LOCK, NETCDFC_LOCK])
 
@@ -40,7 +53,6 @@ class PseudoNetCDFDataStore(AbstractDataStore):
 
     @classmethod
     def open(cls, filename, lock=None, mode=None, **format_kwargs):
-        from PseudoNetCDF import pncopen
 
         keywords = {"kwargs": format_kwargs}
         # only include mode if explicitly passed
@@ -138,3 +150,7 @@ def open_backend_dataset_pseudonetcdf(
     open_dataset=open_backend_dataset_pseudonetcdf,
     open_dataset_parameters=open_dataset_parameters,
 )
+
+
+if has_pseudonetcdf:
+    BACKEND_ENTRYPOINTS["pseudonetcdf"] = pseudonetcdf_backend
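With the registry moved into ``common.py``, ``plugins.py`` no longer hard-codes the backend list; each module appends to ``BACKEND_ENTRYPOINTS`` when imported. In principle an out-of-tree backend could register the same way, sketched below with invented names (``myformat`` and its opener are hypothetical; distributing backends via package entrypoints remains the supported route):

    from xarray.backends.common import BACKEND_ENTRYPOINTS, BackendEntrypoint

    def open_backend_dataset_myformat(filename_or_obj, *, drop_variables=None):
        raise NotImplementedError  # would build and return an xarray.Dataset

    BACKEND_ENTRYPOINTS["myformat"] = BackendEntrypoint(
        open_dataset=open_backend_dataset_myformat
    )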
diff --git a/xarray/backends/pydap_.py b/xarray/backends/pydap_.py
index 4995045a739..c5ce943a10a 100644
--- a/xarray/backends/pydap_.py
+++ b/xarray/backends/pydap_.py
@@ -4,9 +4,22 @@
 from ..core.pycompat import integer_types
 from ..core.utils import Frozen, FrozenDict, close_on_error, is_dict_like, is_remote_uri
 from ..core.variable import Variable
-from .common import AbstractDataStore, BackendArray, BackendEntrypoint, robust_getitem
+from .common import (
+    BACKEND_ENTRYPOINTS,
+    AbstractDataStore,
+    BackendArray,
+    BackendEntrypoint,
+    robust_getitem,
+)
 from .store import open_backend_dataset_store
 
+try:
+    import pydap.client
+
+    has_pydap = True
+except ModuleNotFoundError:
+    has_pydap = False
+
 
 class PydapArrayWrapper(BackendArray):
     def __init__(self, array):
@@ -74,7 +87,6 @@ def __init__(self, ds):
 
     @classmethod
     def open(cls, url, session=None):
-        import pydap.client
 
         ds = pydap.client.open_url(url, session=session)
         return cls(ds)
@@ -133,3 +145,7 @@ def open_backend_dataset_pydap(
 pydap_backend = BackendEntrypoint(
     open_dataset=open_backend_dataset_pydap, guess_can_open=guess_can_open_pydap
 )
+
+
+if has_pydap:
+    BACKEND_ENTRYPOINTS["pydap"] = pydap_backend
diff --git a/xarray/backends/pynio_.py b/xarray/backends/pynio_.py
index dc6c47935e8..261daa69880 100644
--- a/xarray/backends/pynio_.py
+++ b/xarray/backends/pynio_.py
@@ -3,11 +3,24 @@
 from ..core import indexing
 from ..core.utils import Frozen, FrozenDict, close_on_error
 from ..core.variable import Variable
-from .common import AbstractDataStore, BackendArray, BackendEntrypoint
+from .common import (
+    BACKEND_ENTRYPOINTS,
+    AbstractDataStore,
+    BackendArray,
+    BackendEntrypoint,
+)
 from .file_manager import CachingFileManager
 from .locks import HDF5_LOCK, NETCDFC_LOCK, SerializableLock, combine_locks, ensure_lock
 from .store import open_backend_dataset_store
 
+try:
+    import Nio
+
+    has_pynio = True
+except ModuleNotFoundError:
+    has_pynio = False
+
+
 # PyNIO can invoke netCDF libraries internally
 # Add a dedicated lock just in case NCL as well isn't thread-safe.
 NCL_LOCK = SerializableLock()
@@ -45,7 +58,6 @@ class NioDataStore(AbstractDataStore):
     """Store for accessing datasets via PyNIO"""
 
     def __init__(self, filename, mode="r", lock=None, **kwargs):
-        import Nio
 
         if lock is None:
             lock = PYNIO_LOCK
@@ -119,3 +131,7 @@ def open_backend_dataset_pynio(
 
 
 pynio_backend = BackendEntrypoint(open_dataset=open_backend_dataset_pynio)
+
+
+if has_pynio:
+    BACKEND_ENTRYPOINTS["pynio"] = pynio_backend
diff --git a/xarray/backends/rasterio_.py b/xarray/backends/rasterio_.py
index a0500c7e1c2..c689c1e99d7 100644
--- a/xarray/backends/rasterio_.py
+++ b/xarray/backends/rasterio_.py
@@ -361,6 +361,6 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None, loc
         result = result.chunk(chunks, name_prefix=name_prefix, token=token)
 
     # Make the file closeable
-    result._file_obj = manager
+    result.set_close(manager.close)
 
     return result
diff --git a/xarray/backends/scipy_.py b/xarray/backends/scipy_.py
index 873a91f9c07..df51d07d686 100644
--- a/xarray/backends/scipy_.py
+++ b/xarray/backends/scipy_.py
@@ -6,12 +6,24 @@
 from ..core.indexing import NumpyIndexingAdapter
 from ..core.utils import Frozen, FrozenDict, close_on_error, read_magic_number
 from ..core.variable import Variable
-from .common import BackendArray, BackendEntrypoint, WritableCFDataStore
+from .common import (
+    BACKEND_ENTRYPOINTS,
+    BackendArray,
+    BackendEntrypoint,
+    WritableCFDataStore,
+)
 from .file_manager import CachingFileManager, DummyFileManager
 from .locks import ensure_lock, get_write_lock
 from .netcdf3 import encode_nc3_attr_value, encode_nc3_variable, is_valid_nc3_name
 from .store import open_backend_dataset_store
 
+try:
+    import scipy.io
+
+    has_scipy = True
+except ModuleNotFoundError:
+    has_scipy = False
+
 
 def _decode_string(s):
     if isinstance(s, bytes):
@@ -61,8 +73,6 @@ def __setitem__(self, key, value):
 def _open_scipy_netcdf(filename, mode, mmap, version):
     import gzip
 
-    import scipy.io
-
     # if the string ends with .gz, then gunzip and open as netcdf file
     if isinstance(filename, str) and filename.endswith(".gz"):
         try:
@@ -271,3 +281,7 @@ def open_backend_dataset_scipy(
 scipy_backend = BackendEntrypoint(
     open_dataset=open_backend_dataset_scipy, guess_can_open=guess_can_open_scipy
 )
+
+
+if has_scipy:
+    BACKEND_ENTRYPOINTS["scipy"] = scipy_backend
diff --git a/xarray/backends/store.py b/xarray/backends/store.py
index d314a9c3ca9..66fca0d39c3 100644
--- a/xarray/backends/store.py
+++ b/xarray/backends/store.py
@@ -1,6 +1,6 @@
 from .. import conventions
 from ..core.dataset import Dataset
-from .common import AbstractDataStore, BackendEntrypoint
+from .common import BACKEND_ENTRYPOINTS, AbstractDataStore, BackendEntrypoint
 
 
 def guess_can_open_store(store_spec):
@@ -19,7 +19,6 @@ def open_backend_dataset_store(
     decode_timedelta=None,
 ):
     vars, attrs = store.load()
-    file_obj = store
     encoding = store.get_encoding()
 
     vars, attrs, coord_names = conventions.decode_cf_variables(
@@ -36,7 +35,7 @@ def open_backend_dataset_store(
 
     ds = Dataset(vars, attrs=attrs)
     ds = ds.set_coords(coord_names.intersection(vars))
-    ds._file_obj = file_obj
+    ds.set_close(store.close)
     ds.encoding = encoding
 
     return ds
@@ -45,3 +44,6 @@ def open_backend_dataset_store(
 store_backend = BackendEntrypoint(
     open_dataset=open_backend_dataset_store, guess_can_open=guess_can_open_store
 )
+
+
+BACKEND_ENTRYPOINTS["store"] = store_backend
diff --git a/xarray/backends/zarr.py b/xarray/backends/zarr.py
index 3b4b3a3d9d5..ceeb23cac9b 100644
--- a/xarray/backends/zarr.py
+++ b/xarray/backends/zarr.py
@@ -9,6 +9,7 @@
 from ..core.utils import FrozenDict, HiddenKeyDict, close_on_error
 from ..core.variable import Variable
 from .common import (
+    BACKEND_ENTRYPOINTS,
     AbstractWritableDataStore,
     BackendArray,
     BackendEntrypoint,
@@ -16,6 +17,14 @@
 )
 from .store import open_backend_dataset_store
 
+try:
+    import zarr
+
+    has_zarr = True
+except ModuleNotFoundError:
+    has_zarr = False
+
+
 # need some special secret attributes to tell us the dimensions
 DIMENSION_KEY = "_ARRAY_DIMENSIONS"
 
@@ -289,7 +298,6 @@ def open_group(
         append_dim=None,
         write_region=None,
     ):
-        import zarr
 
         # zarr doesn't support pathlib.Path objects yet. zarr-python#601
         if isinstance(store, pathlib.Path):
@@ -409,7 +417,6 @@ def store(
             dimension on which the zarray will be appended
             only needed in append mode
         """
-        import zarr
 
         existing_variables = {
             vn for vn in variables if _encode_variable_name(vn) in self.ds
@@ -705,3 +712,7 @@ def open_backend_dataset_zarr(
 
 
 zarr_backend = BackendEntrypoint(open_dataset=open_backend_dataset_zarr)
+
+
+if has_zarr:
+    BACKEND_ENTRYPOINTS["zarr"] = zarr_backend
diff --git a/xarray/conventions.py b/xarray/conventions.py
index bb0b92c77a1..e33ae53b31d 100644
--- a/xarray/conventions.py
+++ b/xarray/conventions.py
@@ -576,12 +576,12 @@ def decode_cf(
         vars = obj._variables
         attrs = obj.attrs
         extra_coords = set(obj.coords)
-        file_obj = obj._file_obj
+        close = obj._close
         encoding = obj.encoding
     elif isinstance(obj, AbstractDataStore):
         vars, attrs = obj.load()
         extra_coords = set()
-        file_obj = obj
+        close = obj.close
         encoding = obj.get_encoding()
     else:
         raise TypeError("can only decode Dataset or DataStore objects")
@@ -599,7 +599,7 @@ def decode_cf(
     )
     ds = Dataset(vars, attrs=attrs)
     ds = ds.set_coords(coord_names.union(extra_coords).intersection(vars))
-    ds._file_obj = file_obj
+    ds.set_close(close)
     ds.encoding = encoding
 
     return ds
diff --git a/xarray/core/common.py b/xarray/core/common.py
index 283114770cf..a69ba03a7a4 100644
--- a/xarray/core/common.py
+++ b/xarray/core/common.py
@@ -11,6 +11,7 @@
     Iterator,
     List,
     Mapping,
+    Optional,
     Tuple,
     TypeVar,
     Union,
@@ -330,7 +331,9 @@ def get_squeeze_dims(
 class DataWithCoords(SupportsArithmetic, AttrAccessMixin):
     """Shared base class for Dataset and DataArray."""
 
-    __slots__ = ()
+    _close: Optional[Callable[[], None]]
+
+    __slots__ = ("_close",)
 
     _rolling_exp_cls = RollingExp
 
@@ -1263,11 +1266,27 @@ def where(self, cond, other=dtypes.NA, drop: bool = False):
 
         return ops.where_method(self, cond, other)
 
+    def set_close(self, close: Optional[Callable[[], None]]) -> None:
+        """Register the function that releases any resources linked to this object.
+
+        This method controls how xarray cleans up resources associated
+        with this object when the ``.close()`` method is called. It is
+        mostly intended for backend developers and is rarely needed by
+        regular end-users.
+
+        Parameters
+        ----------
+        close : callable
+            The function that, when called like ``close()``, releases
+            any resources linked to this object.
+        """
+        self._close = close
+
     def close(self: Any) -> None:
-        """Close any files linked to this object"""
-        if self._file_obj is not None:
-            self._file_obj.close()
-        self._file_obj = None
+        """Release any resources linked to this object."""
+        if self._close is not None:
+            self._close()
+        self._close = None
 
     def isnull(self, keep_attrs: bool = None):
         """Test each value in the array for whether it is a missing value.
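The rewritten ``close`` above makes the registered closer single-shot: ``_close`` is reset to ``None`` after the first call, so calling ``close()`` repeatedly is harmless. A small sketch of that behaviour, assuming this patch is applied:

    import xarray as xr

    calls = []
    ds = xr.Dataset()
    ds.set_close(lambda: calls.append("closed"))
    ds.close()
    ds.close()  # no-op: the closer was dropped after the first call
    assert calls == ["closed"]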
diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py
index 1517e007a31..a1eda81cf54 100644
--- a/xarray/core/dataarray.py
+++ b/xarray/core/dataarray.py
@@ -344,6 +344,7 @@ class DataArray(AbstractArray, DataWithCoords):
 
     _cache: Dict[str, Any]
     _coords: Dict[Any, Variable]
+    _close: Optional[Callable[[], None]]
     _indexes: Optional[Dict[Hashable, pd.Index]]
     _name: Optional[Hashable]
     _variable: Variable
@@ -351,7 +352,7 @@ class DataArray(AbstractArray, DataWithCoords):
     __slots__ = (
         "_cache",
         "_coords",
-        "_file_obj",
+        "_close",
         "_indexes",
         "_name",
         "_variable",
@@ -421,7 +422,7 @@ def __init__(
         # public interface.
         self._indexes = indexes
 
-        self._file_obj = None
+        self._close = None
 
     def _replace(
         self,
@@ -2247,6 +2248,28 @@ def drop_sel(
         ds = self._to_temp_dataset().drop_sel(labels, errors=errors)
         return self._from_temp_dataset(ds)
 
+    def drop_isel(self, indexers=None, **indexers_kwargs):
+        """Drop index positions from this DataArray.
+
+        Parameters
+        ----------
+        indexers : mapping of hashable to Any
+            Index locations to drop
+        **indexers_kwargs : {dim: position, ...}, optional
+            The keyword arguments form of ``dim`` and ``positions``
+
+        Returns
+        -------
+        dropped : DataArray
+
+        Raises
+        ------
+        IndexError
+        """
+        dataset = self._to_temp_dataset()
+        dataset = dataset.drop_isel(indexers=indexers, **indexers_kwargs)
+        return self._from_temp_dataset(dataset)
+
     def dropna(
         self, dim: Hashable, how: str = "any", thresh: int = None
     ) -> "DataArray":
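``DataArray.drop_isel`` round-trips through ``_to_temp_dataset``, so it behaves exactly like the ``Dataset`` method that follows. A usage sketch (not part of the patch; values are illustrative):

    import numpy as np
    import xarray as xr

    arr = xr.DataArray(np.arange(6).reshape(2, 3), dims=["x", "y"])
    assert arr.drop_isel(y=[0, 2]).shape == (2, 1)  # drop positions 0 and 2 along y
    assert arr.drop_isel(y=1).shape == (2, 2)  # scalar positions also work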
diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py
index 175034d69de..1f12b355357 100644
--- a/xarray/core/dataset.py
+++ b/xarray/core/dataset.py
@@ -636,6 +636,7 @@ class Dataset(Mapping, ImplementsDatasetReduce, DataWithCoords):
     _coord_names: Set[Hashable]
     _dims: Dict[Hashable, int]
     _encoding: Optional[Dict[Hashable, Any]]
+    _close: Optional[Callable[[], None]]
     _indexes: Optional[Dict[Hashable, pd.Index]]
     _variables: Dict[Hashable, Variable]
 
@@ -645,7 +646,7 @@ class Dataset(Mapping, ImplementsDatasetReduce, DataWithCoords):
         "_coord_names",
         "_dims",
         "_encoding",
-        "_file_obj",
+        "_close",
         "_indexes",
         "_variables",
         "__weakref__",
@@ -687,7 +688,7 @@ def __init__(
         )
 
         self._attrs = dict(attrs) if attrs is not None else None
-        self._file_obj = None
+        self._close = None
         self._encoding = None
         self._variables = variables
         self._coord_names = coord_names
@@ -703,7 +704,7 @@ def load_store(cls, store, decoder=None) -> "Dataset":
         if decoder:
             variables, attributes = decoder(variables, attributes)
         obj = cls(variables, attrs=attributes)
-        obj._file_obj = store
+        obj.set_close(store.close)
         return obj
 
     @property
@@ -876,7 +877,7 @@ def __dask_postcompute__(self):
             self._attrs,
             self._indexes,
             self._encoding,
-            self._file_obj,
+            self._close,
         )
         return self._dask_postcompute, args
 
@@ -896,7 +897,7 @@ def __dask_postpersist__(self):
             self._attrs,
             self._indexes,
             self._encoding,
-            self._file_obj,
+            self._close,
         )
         return self._dask_postpersist, args
 
@@ -1007,7 +1008,7 @@ def _construct_direct(
         attrs=None,
         indexes=None,
         encoding=None,
-        file_obj=None,
+        close=None,
     ):
         """Shortcut around __init__ for internal use when we want to skip
         costly validation
@@ -1020,7 +1021,7 @@ def _construct_direct(
         obj._dims = dims
         obj._indexes = indexes
         obj._attrs = attrs
-        obj._file_obj = file_obj
+        obj._close = close
         obj._encoding = encoding
         return obj
 
@@ -2122,7 +2123,7 @@ def isel(
             attrs=self._attrs,
             indexes=indexes,
             encoding=self._encoding,
-            file_obj=self._file_obj,
+            close=self._close,
         )
 
     def _isel_fancy(
@@ -4053,13 +4054,78 @@ def drop_sel(self, labels=None, *, errors="raise", **labels_kwargs):
                 labels_for_dim = [labels_for_dim]
             labels_for_dim = np.asarray(labels_for_dim)
             try:
-                index = self.indexes[dim]
+                index = self.get_index(dim)
             except KeyError:
                 raise ValueError("dimension %r does not have coordinate labels" % dim)
             new_index = index.drop(labels_for_dim, errors=errors)
             ds = ds.loc[{dim: new_index}]
         return ds
 
+    def drop_isel(self, indexers=None, **indexers_kwargs):
+        """Drop index positions from this Dataset.
+
+        Parameters
+        ----------
+        indexers : mapping of hashable to Any
+            Index locations to drop
+        **indexers_kwargs : {dim: position, ...}, optional
+            The keyword arguments form of ``dim`` and ``positions``
+
+        Returns
+        -------
+        dropped : Dataset
+
+        Raises
+        ------
+        IndexError
+
+        Examples
+        --------
+        >>> data = np.arange(6).reshape(2, 3)
+        >>> labels = ["a", "b", "c"]
+        >>> ds = xr.Dataset({"A": (["x", "y"], data), "y": labels})
+        >>> ds
+        <xarray.Dataset>
+        Dimensions:  (x: 2, y: 3)
+        Coordinates:
+          * y        (y) <U1 'a' 'b' 'c'
+        Dimensions without coordinates: x
+        Data variables:
+            A        (x, y) int64 0 1 2 3 4 5
+        >>> ds.drop_isel(y=[0, 2])
+        <xarray.Dataset>
+        Dimensions:  (x: 2, y: 1)
+        Coordinates:
+          * y        (y) <U1 'b'
+        Dimensions without coordinates: x
+        Data variables:
+            A        (x, y) int64 1 4
+        >>> ds.drop_isel(y=1)
+        <xarray.Dataset>
+        Dimensions:  (x: 2, y: 2)
+        Coordinates:
+          * y        (y) <U1 'a' 'c'
+        Dimensions without coordinates: x
+        Data variables:
+            A        (x, y) int64 0 2 3 5
+        """
+
+        indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "drop_isel")
+
+        ds = self
+        dimension_index = {}
+        for dim, pos_for_dim in indexers.items():
+            # Don't cast to set, as it would harm performance when labels
+            # is a large numpy array
+            if utils.is_scalar(pos_for_dim):
+                pos_for_dim = [pos_for_dim]
+            pos_for_dim = np.asarray(pos_for_dim)
+            index = self.get_index(dim)
+            new_index = index.delete(pos_for_dim)
+            dimension_index[dim] = new_index
+        ds = ds.loc[dimension_index]
+        return ds
+
     def drop_dims(
         self, dims: Union[Hashable, Iterable[Hashable]], *, errors: str = "raise"
     ) -> "Dataset":
diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py
index 3ead427e22e..afb234029dc 100644
--- a/xarray/tests/test_dataarray.py
+++ b/xarray/tests/test_dataarray.py
@@ -2327,6 +2327,12 @@ def test_drop_index_labels(self):
         with pytest.warns(DeprecationWarning):
             arr.drop([0, 1, 3], dim="y", errors="ignore")
 
+    def test_drop_index_positions(self):
+        arr = DataArray(np.random.randn(2, 3), dims=["x", "y"])
+        actual = arr.drop_sel(y=[0, 1])
+        expected = arr[:, 2:]
+        assert_identical(actual, expected)
+
     def test_dropna(self):
         x = np.random.randn(4, 4)
         x[::2, 0] = np.nan
diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py
index d385ec3a48b..35137c6ef5d 100644
--- a/xarray/tests/test_dataset.py
+++ b/xarray/tests/test_dataset.py
@@ -2371,8 +2371,12 @@ def test_drop_index_labels(self):
             data.drop(DataArray(["a", "b", "c"]), dim="x", errors="ignore")
         assert_identical(expected, actual)
 
-        with raises_regex(ValueError, "does not have coordinate labels"):
-            data.drop_sel(y=1)
+        actual = data.drop_sel(y=[1])
+        expected = data.isel(y=[0, 2])
+        assert_identical(expected, actual)
+
+        with raises_regex(KeyError, "not found in axis"):
+            data.drop_sel(x=0)
 
     def test_drop_labels_by_keyword(self):
         data = Dataset(
@@ -2410,6 +2414,34 @@ def test_drop_labels_by_keyword(self):
         with pytest.raises(ValueError):
             data.drop(dim="x", x="a")
 
+    def test_drop_labels_by_position(self):
+        data = Dataset(
+            {"A": (["x", "y"], np.random.randn(2, 6)), "x": ["a", "b"], "y": range(6)}
+        )
+        # Basic functionality.
+        assert len(data.coords["x"]) == 2
+
+        actual = data.drop_isel(x=0)
+        expected = data.drop_sel(x="a")
+        assert_identical(expected, actual)
+
+        actual = data.drop_isel(x=[0])
+        expected = data.drop_sel(x=["a"])
+        assert_identical(expected, actual)
+
+        actual = data.drop_isel(x=[0, 1])
+        expected = data.drop_sel(x=["a", "b"])
+        assert_identical(expected, actual)
+        assert actual.coords["x"].size == 0
+
+        actual = data.drop_isel(x=[0, 1], y=range(0, 6, 2))
+        expected = data.drop_sel(x=["a", "b"], y=range(0, 6, 2))
+        assert_identical(expected, actual)
+        assert actual.coords["x"].size == 0
+
+        with pytest.raises(KeyError):
+            data.drop_isel(z=1)
+
     def test_drop_dims(self):
         data = xr.Dataset(
             {
diff --git a/xarray/tests/test_plugins.py b/xarray/tests/test_plugins.py
index 110ef47209f..38ebce6da1a 100644
--- a/xarray/tests/test_plugins.py
+++ b/xarray/tests/test_plugins.py
@@ -92,7 +92,7 @@ def test_set_missing_parameters_raise_error():
     with pytest.raises(TypeError):
         plugins.set_missing_parameters({"engine": backend})
 
-    backend = plugins.BackendEntrypoint(
+    backend = common.BackendEntrypoint(
         dummy_open_dataset_kwargs, ("filename_or_obj", "decoder")
     )
     plugins.set_missing_parameters({"engine": backend})