From db0df5d8695b8c521ed287f34a2558f50c898e1e Mon Sep 17 00:00:00 2001 From: Alessandro Amici Date: Thu, 14 Jan 2021 12:35:17 +0100 Subject: [PATCH 01/17] Move from _file_obj object to _close function --- xarray/backends/api.py | 22 +++++++++++----------- xarray/backends/apiv2.py | 2 +- xarray/backends/rasterio_.py | 2 +- xarray/backends/store.py | 3 +-- xarray/conventions.py | 6 +++--- xarray/core/common.py | 6 +++--- xarray/core/dataarray.py | 6 ++++-- xarray/core/dataset.py | 18 ++++++++++-------- 8 files changed, 34 insertions(+), 31 deletions(-) diff --git a/xarray/backends/api.py b/xarray/backends/api.py index faa7e6cf3d3..eae130f9a28 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -522,7 +522,7 @@ def maybe_decode_store(store, chunks): else: ds2 = ds - ds2._file_obj = ds._file_obj + ds2._close = ds._close return ds2 filename_or_obj = _normalize_path(filename_or_obj) @@ -701,7 +701,7 @@ def open_dataarray( else: (data_array,) = dataset.data_vars.values() - data_array._file_obj = dataset._file_obj + data_array._close = dataset._close # Reset names if they were changed during saving # to ensure that we can 'roundtrip' perfectly @@ -716,14 +716,14 @@ def open_dataarray( class _MultiFileCloser: - __slots__ = ("file_objs",) + __slots__ = ("closers",) - def __init__(self, file_objs): - self.file_objs = file_objs + def __init__(self, closers): + self.closers = closers - def close(self): - for f in self.file_objs: - f.close() + def __call__(self): + for dataset_close in self.closers: + dataset_close() def open_mfdataset( @@ -918,14 +918,14 @@ def open_mfdataset( getattr_ = getattr datasets = [open_(p, **open_kwargs) for p in paths] - file_objs = [getattr_(ds, "_file_obj") for ds in datasets] + closers = [getattr_(ds, "_close") for ds in datasets] if preprocess is not None: datasets = [preprocess(ds) for ds in datasets] if parallel: # calling compute here will return the datasets/file_objs lists, # the underlying datasets will still be stored as dask arrays - datasets, file_objs = dask.compute(datasets, file_objs) + datasets, closers = dask.compute(datasets, closers) # Combine all datasets, closing them in case of a ValueError try: @@ -963,7 +963,7 @@ def open_mfdataset( ds.close() raise - combined._file_obj = _MultiFileCloser(file_objs) + combined._close = _MultiFileCloser(closers) # read global attributes from the attrs_file or from the first dataset if attrs_file is not None: diff --git a/xarray/backends/apiv2.py b/xarray/backends/apiv2.py index 0f98291983d..a1018abe9fc 100644 --- a/xarray/backends/apiv2.py +++ b/xarray/backends/apiv2.py @@ -90,7 +90,7 @@ def _dataset_from_backend_dataset( **extra_tokens, ) - ds._file_obj = backend_ds._file_obj + ds._close = backend_ds._close # Ensure source filename always stored in dataset object (GH issue #2550) if "source" not in ds.encoding: diff --git a/xarray/backends/rasterio_.py b/xarray/backends/rasterio_.py index a0500c7e1c2..3ff568965f3 100644 --- a/xarray/backends/rasterio_.py +++ b/xarray/backends/rasterio_.py @@ -361,6 +361,6 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None, loc result = result.chunk(chunks, name_prefix=name_prefix, token=token) # Make the file closeable - result._file_obj = manager + result._close = manager.close return result diff --git a/xarray/backends/store.py b/xarray/backends/store.py index d314a9c3ca9..57880db06cf 100644 --- a/xarray/backends/store.py +++ b/xarray/backends/store.py @@ -19,7 +19,6 @@ def open_backend_dataset_store( decode_timedelta=None, ): vars, attrs 
= store.load() - file_obj = store encoding = store.get_encoding() vars, attrs, coord_names = conventions.decode_cf_variables( @@ -36,7 +35,7 @@ def open_backend_dataset_store( ds = Dataset(vars, attrs=attrs) ds = ds.set_coords(coord_names.intersection(vars)) - ds._file_obj = file_obj + ds._close = store.close ds.encoding = encoding return ds diff --git a/xarray/conventions.py b/xarray/conventions.py index bb0b92c77a1..2fe6dd81aa8 100644 --- a/xarray/conventions.py +++ b/xarray/conventions.py @@ -576,12 +576,12 @@ def decode_cf( vars = obj._variables attrs = obj.attrs extra_coords = set(obj.coords) - file_obj = obj._file_obj + close = obj._close encoding = obj.encoding elif isinstance(obj, AbstractDataStore): vars, attrs = obj.load() extra_coords = set() - file_obj = obj + close = obj.close encoding = obj.get_encoding() else: raise TypeError("can only decode Dataset or DataStore objects") @@ -599,7 +599,7 @@ def decode_cf( ) ds = Dataset(vars, attrs=attrs) ds = ds.set_coords(coord_names.union(extra_coords).intersection(vars)) - ds._file_obj = file_obj + ds._close = close ds.encoding = encoding return ds diff --git a/xarray/core/common.py b/xarray/core/common.py index 283114770cf..e7f75636af0 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -1265,9 +1265,9 @@ def where(self, cond, other=dtypes.NA, drop: bool = False): def close(self: Any) -> None: """Close any files linked to this object""" - if self._file_obj is not None: - self._file_obj.close() - self._file_obj = None + if self._close is not None: + self._close() + self._close = None def isnull(self, keep_attrs: bool = None): """Test each value in the array for whether it is a missing value. diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 6fdda8fc418..49fa9b80e80 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -344,6 +344,7 @@ class DataArray(AbstractArray, DataWithCoords): _cache: Dict[str, Any] _coords: Dict[Any, Variable] + _close: Optional[Callable[[], None]] _indexes: Optional[Dict[Hashable, pd.Index]] _name: Optional[Hashable] _variable: Variable @@ -351,7 +352,7 @@ class DataArray(AbstractArray, DataWithCoords): __slots__ = ( "_cache", "_coords", - "_file_obj", + "_close", "_indexes", "_name", "_variable", @@ -376,6 +377,7 @@ def __init__( # internal parameters indexes: Dict[Hashable, pd.Index] = None, fastpath: bool = False, + close: Optional[Callable[[], None]] = None, ): if fastpath: variable = data @@ -421,7 +423,7 @@ def __init__( # public interface. 
self._indexes = indexes - self._file_obj = None + self._close = close def _replace( self, diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 7edc2fab067..6fd291dea53 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -636,6 +636,7 @@ class Dataset(Mapping, ImplementsDatasetReduce, DataWithCoords): _coord_names: Set[Hashable] _dims: Dict[Hashable, int] _encoding: Optional[Dict[Hashable, Any]] + _close: Optional[Callable[[], None]] _indexes: Optional[Dict[Hashable, pd.Index]] _variables: Dict[Hashable, Variable] @@ -645,7 +646,7 @@ class Dataset(Mapping, ImplementsDatasetReduce, DataWithCoords): "_coord_names", "_dims", "_encoding", - "_file_obj", + "_close", "_indexes", "_variables", "__weakref__", @@ -664,6 +665,7 @@ def __init__( data_vars: Mapping[Hashable, Any] = None, coords: Mapping[Hashable, Any] = None, attrs: Mapping[Hashable, Any] = None, + close: Optional[Callable[[], None]] = None, ): # TODO(shoyer): expose indexes as a public argument in __init__ @@ -687,7 +689,7 @@ def __init__( ) self._attrs = dict(attrs) if attrs is not None else None - self._file_obj = None + self._close = close self._encoding = None self._variables = variables self._coord_names = coord_names @@ -703,7 +705,7 @@ def load_store(cls, store, decoder=None) -> "Dataset": if decoder: variables, attributes = decoder(variables, attributes) obj = cls(variables, attrs=attributes) - obj._file_obj = store + obj._close = store.close return obj @property @@ -876,7 +878,7 @@ def __dask_postcompute__(self): self._attrs, self._indexes, self._encoding, - self._file_obj, + self._close, ) return self._dask_postcompute, args @@ -896,7 +898,7 @@ def __dask_postpersist__(self): self._attrs, self._indexes, self._encoding, - self._file_obj, + self._close, ) return self._dask_postpersist, args @@ -1007,7 +1009,7 @@ def _construct_direct( attrs=None, indexes=None, encoding=None, - file_obj=None, + close=None, ): """Shortcut around __init__ for internal use when we want to skip costly validation @@ -1020,7 +1022,7 @@ def _construct_direct( obj._dims = dims obj._indexes = indexes obj._attrs = attrs - obj._file_obj = file_obj + obj._close = close obj._encoding = encoding return obj @@ -2122,7 +2124,7 @@ def isel( attrs=self._attrs, indexes=indexes, encoding=self._encoding, - file_obj=self._file_obj, + close=self._close, ) def _isel_fancy( From cbb89ff8d47708219c0c5c2386d869b312c1b773 Mon Sep 17 00:00:00 2001 From: Alessandro Amici Date: Thu, 14 Jan 2021 14:26:44 +0100 Subject: [PATCH 02/17] Remove all references to _close outside of low level --- xarray/backends/api.py | 3 --- xarray/backends/apiv2.py | 2 -- xarray/backends/rasterio_.py | 11 +++++++---- xarray/backends/store.py | 3 +-- xarray/conventions.py | 3 +-- xarray/core/dataarray.py | 15 ++++++++++++--- xarray/core/dataset.py | 24 +++++++++++++++++++++--- 7 files changed, 42 insertions(+), 19 deletions(-) diff --git a/xarray/backends/api.py b/xarray/backends/api.py index eae130f9a28..66c2aa2f7f2 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -522,7 +522,6 @@ def maybe_decode_store(store, chunks): else: ds2 = ds - ds2._close = ds._close return ds2 filename_or_obj = _normalize_path(filename_or_obj) @@ -701,8 +700,6 @@ def open_dataarray( else: (data_array,) = dataset.data_vars.values() - data_array._close = dataset._close - # Reset names if they were changed during saving # to ensure that we can 'roundtrip' perfectly if DATAARRAY_NAME in dataset.attrs: diff --git a/xarray/backends/apiv2.py b/xarray/backends/apiv2.py index 
a1018abe9fc..1d649896e64 100644 --- a/xarray/backends/apiv2.py +++ b/xarray/backends/apiv2.py @@ -90,8 +90,6 @@ def _dataset_from_backend_dataset( **extra_tokens, ) - ds._close = backend_ds._close - # Ensure source filename always stored in dataset object (GH issue #2550) if "source" not in ds.encoding: if isinstance(filename_or_obj, str): diff --git a/xarray/backends/rasterio_.py b/xarray/backends/rasterio_.py index 3ff568965f3..47df283e55c 100644 --- a/xarray/backends/rasterio_.py +++ b/xarray/backends/rasterio_.py @@ -345,7 +345,13 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None, loc if cache and chunks is None: data = indexing.MemoryCachedArray(data) - result = DataArray(data=data, dims=("band", "y", "x"), coords=coords, attrs=attrs) + result = DataArray( + data=data, + dims=("band", "y", "x"), + coords=coords, + attrs=attrs, + close=manager.close, + ) if chunks is not None: from dask.base import tokenize @@ -360,7 +366,4 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None, loc name_prefix = "open_rasterio-%s" % token result = result.chunk(chunks, name_prefix=name_prefix, token=token) - # Make the file closeable - result._close = manager.close - return result diff --git a/xarray/backends/store.py b/xarray/backends/store.py index 57880db06cf..281af0839cf 100644 --- a/xarray/backends/store.py +++ b/xarray/backends/store.py @@ -33,9 +33,8 @@ def open_backend_dataset_store( decode_timedelta=decode_timedelta, ) - ds = Dataset(vars, attrs=attrs) + ds = Dataset(vars, attrs=attrs, close=store.close) ds = ds.set_coords(coord_names.intersection(vars)) - ds._close = store.close ds.encoding = encoding return ds diff --git a/xarray/conventions.py b/xarray/conventions.py index 2fe6dd81aa8..1a521f56b36 100644 --- a/xarray/conventions.py +++ b/xarray/conventions.py @@ -597,9 +597,8 @@ def decode_cf( use_cftime=use_cftime, decode_timedelta=decode_timedelta, ) - ds = Dataset(vars, attrs=attrs) + ds = Dataset(vars, attrs=attrs, close=close) ds = ds.set_coords(coord_names.union(extra_coords).intersection(vars)) - ds._close = close ds.encoding = encoding return ds diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 49fa9b80e80..e021b431b6f 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -431,6 +431,7 @@ def _replace( coords=None, name: Union[Hashable, None, Default] = _default, indexes=None, + close: Optional[Callable[[], None]] = _default, ) -> "DataArray": if variable is None: variable = self.variable @@ -438,7 +439,11 @@ def _replace( coords = self._coords if name is _default: name = self.name - return type(self)(variable, coords, name=name, fastpath=True, indexes=indexes) + if close is _default: + close = self._close + return type(self)( + variable, coords, name=name, fastpath=True, indexes=indexes, close=close + ) def _replace_maybe_drop_dims( self, variable: Variable, name: Union[Hashable, None, Default] = _default @@ -494,7 +499,8 @@ def _from_temp_dataset( variable = dataset._variables.pop(_THIS_ARRAY) coords = dataset._variables indexes = dataset._indexes - return self._replace(variable, coords, name, indexes=indexes) + close = dataset._close + return self._replace(variable, coords, name, indexes=indexes, close=close) def _to_dataset_split(self, dim: Hashable) -> Dataset: """ splits dataarray along dimension 'dim' """ @@ -538,7 +544,10 @@ def _to_dataset_whole( indexes = self._indexes coord_names = set(self._coords) - dataset = Dataset._construct_direct(variables, coord_names, indexes=indexes) + 
close = self._close + dataset = Dataset._construct_direct( + variables, coord_names, indexes=indexes, close=close + ) return dataset def to_dataset( diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 6fd291dea53..bb331659447 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -1035,6 +1035,7 @@ def _replace( indexes: Union[Dict[Any, pd.Index], None, Default] = _default, encoding: Union[dict, None, Default] = _default, inplace: bool = False, + close: Optional[Callable[[], None]] = _default, ) -> "Dataset": """Fastpath constructor for internal use. @@ -1057,6 +1058,8 @@ def _replace( self._indexes = indexes if encoding is not _default: self._encoding = encoding + if close is not _default: + self._close = close obj = self else: if variables is None: @@ -1071,8 +1074,10 @@ def _replace( indexes = copy.copy(self._indexes) if encoding is _default: encoding = copy.copy(self._encoding) + if close is _default: + close = self._close obj = self._construct_direct( - variables, coord_names, dims, attrs, indexes, encoding + variables, coord_names, dims, attrs, indexes, encoding, close ) return obj @@ -1332,7 +1337,14 @@ def _construct_dataarray(self, name: Hashable) -> "DataArray": else: indexes = {k: v for k, v in self._indexes.items() if k in coords} - return DataArray(variable, coords, name=name, indexes=indexes, fastpath=True) + return DataArray( + variable, + coords, + name=name, + indexes=indexes, + fastpath=True, + close=self._close, + ) def __copy__(self) -> "Dataset": return self.copy(deep=False) @@ -4789,7 +4801,13 @@ def to_array(self, dim="variable", name=None): dims = (dim,) + broadcast_vars[0].dims return DataArray( - data, coords, dims, attrs=self.attrs, name=name, indexes=indexes + data, + coords, + dims, + attrs=self.attrs, + name=name, + indexes=indexes, + close=self._close, ) def _normalize_dim_order( From fed7d90dbc9ade237e5f4428030cd4383bef7215 Mon Sep 17 00:00:00 2001 From: Alessandro Amici Date: Thu, 14 Jan 2021 16:13:23 +0100 Subject: [PATCH 03/17] Fix type hints --- xarray/core/dataarray.py | 2 +- xarray/core/dataset.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index e021b431b6f..cb4ef527152 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -431,7 +431,7 @@ def _replace( coords=None, name: Union[Hashable, None, Default] = _default, indexes=None, - close: Optional[Callable[[], None]] = _default, + close: Union[Callable[[], None], None, Default] = _default, ) -> "DataArray": if variable is None: variable = self.variable diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index bb331659447..ec23f62b888 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -1035,7 +1035,7 @@ def _replace( indexes: Union[Dict[Any, pd.Index], None, Default] = _default, encoding: Union[dict, None, Default] = _default, inplace: bool = False, - close: Optional[Callable[[], None]] = _default, + close: Union[Callable[[], None], None, Default] = _default, ) -> "Dataset": """Fastpath constructor for internal use. 
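Patches 02 and 03 thread an optional ``close`` callable through ``DataArray._replace`` and ``Dataset._replace`` by means of the ``_default`` sentinel, so that "argument not passed" can be told apart from an explicit ``None``; that is also why patch 03 widens the annotation from ``Optional[Callable[[], None]]`` to ``Union[Callable[[], None], None, Default]``. The standalone sketch below illustrates that sentinel pattern; the ``Default`` enum mirrors the one in ``xarray.core.utils``, while ``_replace_close`` is a hypothetical helper shown only for illustration::

    from enum import Enum
    from typing import Callable, Optional, Union

    class Default(Enum):
        token = 0

    _default = Default.token

    def _replace_close(
        current: Optional[Callable[[], None]],
        close: Union[Callable[[], None], None, Default] = _default,
    ) -> Optional[Callable[[], None]]:
        # ``_default`` means "keep the current closer"; an explicit ``None``
        # means "drop it".  The sentinel is a ``Default`` member rather than
        # ``None``, so ``Optional[Callable[[], None]]`` cannot describe the
        # parameter, which is what patch 03 corrects.
        if close is _default:
            return current
        return close

    keep = _replace_close(print)        # sentinel: keeps ``print`` as the closer
    drop = _replace_close(print, None)  # explicit None: drops the closer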
From b2c873992c2d75f66ad8b6b5aac26375204c879e Mon Sep 17 00:00:00 2001 From: Alessandro Amici Date: Thu, 14 Jan 2021 16:47:55 +0100 Subject: [PATCH 04/17] Cleanup code style --- xarray/core/dataarray.py | 3 ++- xarray/core/dataset.py | 6 ++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index cb4ef527152..ac3cff2cd67 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -441,9 +441,10 @@ def _replace( name = self.name if close is _default: close = self._close - return type(self)( + replaced = type(self)( variable, coords, name=name, fastpath=True, indexes=indexes, close=close ) + return replaced def _replace_maybe_drop_dims( self, variable: Variable, name: Union[Hashable, None, Default] = _default diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index ec23f62b888..be1227ede6d 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -1337,7 +1337,7 @@ def _construct_dataarray(self, name: Hashable) -> "DataArray": else: indexes = {k: v for k, v in self._indexes.items() if k in coords} - return DataArray( + da = DataArray( variable, coords, name=name, @@ -1345,6 +1345,7 @@ def _construct_dataarray(self, name: Hashable) -> "DataArray": fastpath=True, close=self._close, ) + return da def __copy__(self) -> "Dataset": return self.copy(deep=False) @@ -4800,7 +4801,7 @@ def to_array(self, dim="variable", name=None): dims = (dim,) + broadcast_vars[0].dims - return DataArray( + da = DataArray( data, coords, dims, @@ -4809,6 +4810,7 @@ def to_array(self, dim="variable", name=None): indexes=indexes, close=self._close, ) + return da def _normalize_dim_order( self, dim_order: List[Hashable] = None From 8b560259092e68dc0325f6dfbeffb9e7885b2972 Mon Sep 17 00:00:00 2001 From: Alessandro Amici Date: Thu, 14 Jan 2021 16:53:30 +0100 Subject: [PATCH 05/17] Fix non-trivial type hint problem --- xarray/core/common.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/xarray/core/common.py b/xarray/core/common.py index e7f75636af0..e70cde01fda 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -11,6 +11,7 @@ Iterator, List, Mapping, + Optional, Tuple, TypeVar, Union, @@ -1265,6 +1266,7 @@ def where(self, cond, other=dtypes.NA, drop: bool = False): def close(self: Any) -> None: """Close any files linked to this object""" + self._close: Optional[Callable[[], None]] if self._close is not None: self._close() self._close = None From 4f6b899755e5fede425bb1211dfcb3ac928bc2d7 Mon Sep 17 00:00:00 2001 From: Alessandro Amici Date: Thu, 14 Jan 2021 19:07:27 +0100 Subject: [PATCH 06/17] Revert adding the `close` argument and add a set_close instead --- xarray/backends/rasterio_.py | 2 +- xarray/backends/store.py | 3 ++- xarray/conventions.py | 3 ++- xarray/core/common.py | 8 ++++++-- xarray/core/dataarray.py | 6 +++--- xarray/core/dataset.py | 7 +++---- 6 files changed, 17 insertions(+), 12 deletions(-) diff --git a/xarray/backends/rasterio_.py b/xarray/backends/rasterio_.py index 47df283e55c..1e2d1d31034 100644 --- a/xarray/backends/rasterio_.py +++ b/xarray/backends/rasterio_.py @@ -350,8 +350,8 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None, loc dims=("band", "y", "x"), coords=coords, attrs=attrs, - close=manager.close, ) + result.set_close(manager.close) if chunks is not None: from dask.base import tokenize diff --git a/xarray/backends/store.py b/xarray/backends/store.py index 281af0839cf..802c78d241d 100644 --- a/xarray/backends/store.py +++ 
b/xarray/backends/store.py @@ -33,7 +33,8 @@ def open_backend_dataset_store( decode_timedelta=decode_timedelta, ) - ds = Dataset(vars, attrs=attrs, close=store.close) + ds = Dataset(vars, attrs=attrs) + ds.set_close(store.close) ds = ds.set_coords(coord_names.intersection(vars)) ds.encoding = encoding diff --git a/xarray/conventions.py b/xarray/conventions.py index 1a521f56b36..76f242c3f13 100644 --- a/xarray/conventions.py +++ b/xarray/conventions.py @@ -597,7 +597,8 @@ def decode_cf( use_cftime=use_cftime, decode_timedelta=decode_timedelta, ) - ds = Dataset(vars, attrs=attrs, close=close) + ds = Dataset(vars, attrs=attrs) + ds.set_close(close) ds = ds.set_coords(coord_names.union(extra_coords).intersection(vars)) ds.encoding = encoding diff --git a/xarray/core/common.py b/xarray/core/common.py index e70cde01fda..3b3ace6a364 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -331,7 +331,9 @@ def get_squeeze_dims( class DataWithCoords(SupportsArithmetic, AttrAccessMixin): """Shared base class for Dataset and DataArray.""" - __slots__ = () + _close: Optional[Callable[[], None]] + + __slots__ = ("_close",) _rolling_exp_cls = RollingExp @@ -1264,9 +1266,11 @@ def where(self, cond, other=dtypes.NA, drop: bool = False): return ops.where_method(self, cond, other) + def set_close(self, close: Optional[Callable[[], None]]) -> None: + self._close = close + def close(self: Any) -> None: """Close any files linked to this object""" - self._close: Optional[Callable[[], None]] if self._close is not None: self._close() self._close = None diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index ac3cff2cd67..6183709d337 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -377,7 +377,6 @@ def __init__( # internal parameters indexes: Dict[Hashable, pd.Index] = None, fastpath: bool = False, - close: Optional[Callable[[], None]] = None, ): if fastpath: variable = data @@ -423,7 +422,7 @@ def __init__( # public interface. 
self._indexes = indexes - self._close = close + self._close = None def _replace( self, @@ -442,8 +441,9 @@ def _replace( if close is _default: close = self._close replaced = type(self)( - variable, coords, name=name, fastpath=True, indexes=indexes, close=close + variable, coords, name=name, fastpath=True, indexes=indexes ) + replaced.set_close(close) return replaced def _replace_maybe_drop_dims( diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index be1227ede6d..a80e5be1f49 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -665,7 +665,6 @@ def __init__( data_vars: Mapping[Hashable, Any] = None, coords: Mapping[Hashable, Any] = None, attrs: Mapping[Hashable, Any] = None, - close: Optional[Callable[[], None]] = None, ): # TODO(shoyer): expose indexes as a public argument in __init__ @@ -689,7 +688,7 @@ def __init__( ) self._attrs = dict(attrs) if attrs is not None else None - self._close = close + self._close = None self._encoding = None self._variables = variables self._coord_names = coord_names @@ -1343,8 +1342,8 @@ def _construct_dataarray(self, name: Hashable) -> "DataArray": name=name, indexes=indexes, fastpath=True, - close=self._close, ) + da.set_close(self._close) return da def __copy__(self) -> "Dataset": @@ -4808,8 +4807,8 @@ def to_array(self, dim="variable", name=None): attrs=self.attrs, name=name, indexes=indexes, - close=self._close, ) + da.set_close(self._close) return da def _normalize_dim_order( From eb6e89c2df8f8a572646eae7b7ec2ce5cd2b6fd0 Mon Sep 17 00:00:00 2001 From: Alessandro Amici Date: Thu, 14 Jan 2021 19:20:26 +0100 Subject: [PATCH 07/17] Remove helper class for an easier helper function + code style --- xarray/backends/api.py | 17 +++++------------ xarray/core/dataset.py | 17 ++--------------- 2 files changed, 7 insertions(+), 27 deletions(-) diff --git a/xarray/backends/api.py b/xarray/backends/api.py index 66c2aa2f7f2..2b5c3d255a4 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -712,17 +712,6 @@ def open_dataarray( return data_array -class _MultiFileCloser: - __slots__ = ("closers",) - - def __init__(self, closers): - self.closers = closers - - def __call__(self): - for dataset_close in self.closers: - dataset_close() - - def open_mfdataset( paths, chunks=None, @@ -960,7 +949,11 @@ def open_mfdataset( ds.close() raise - combined._close = _MultiFileCloser(closers) + def multi_file_closer(): + for closer in closers: + closer() + + combined._close = multi_file_closer # read global attributes from the attrs_file or from the first dataset if attrs_file is not None: diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index a80e5be1f49..cca16dfdfd4 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -1336,13 +1336,7 @@ def _construct_dataarray(self, name: Hashable) -> "DataArray": else: indexes = {k: v for k, v in self._indexes.items() if k in coords} - da = DataArray( - variable, - coords, - name=name, - indexes=indexes, - fastpath=True, - ) + da = DataArray(variable, coords, name=name, indexes=indexes, fastpath=True) da.set_close(self._close) return da @@ -4800,14 +4794,7 @@ def to_array(self, dim="variable", name=None): dims = (dim,) + broadcast_vars[0].dims - da = DataArray( - data, - coords, - dims, - attrs=self.attrs, - name=name, - indexes=indexes, - ) + da = DataArray(data, coords, dims, attrs=self.attrs, name=name, indexes=indexes) da.set_close(self._close) return da From 552b65247b107c71ff7edeeb1fe8673cce0986be Mon Sep 17 00:00:00 2001 From: Alessandro Amici Date: Thu, 14 Jan 2021 
19:40:29 +0100 Subject: [PATCH 08/17] Add set_close docstring --- xarray/core/common.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/xarray/core/common.py b/xarray/core/common.py index 3b3ace6a364..e84b44d7441 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -1266,8 +1266,18 @@ def where(self, cond, other=dtypes.NA, drop: bool = False): return ops.where_method(self, cond, other) - def set_close(self, close: Optional[Callable[[], None]]) -> None: - self._close = close + def set_close(self, close_store: Optional[Callable[[], None]]) -> None: + """Register the function that releases all resources used by the data store. + + This method is mostly intended for backend developers and it is rarely + needed by regular end-users. + + Parameters + ---------- + close_store : callable + A callable that releases all resources used by the data store. + """ + self._close = close_store def close(self: Any) -> None: """Close any files linked to this object""" From 947c3b527c0c1700fb4156a2feced9f3e6d51aa5 Mon Sep 17 00:00:00 2001 From: Alessandro Amici Date: Thu, 14 Jan 2021 19:52:38 +0100 Subject: [PATCH 09/17] Code style --- xarray/backends/rasterio_.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/xarray/backends/rasterio_.py b/xarray/backends/rasterio_.py index 1e2d1d31034..f96a892521a 100644 --- a/xarray/backends/rasterio_.py +++ b/xarray/backends/rasterio_.py @@ -345,12 +345,7 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None, loc if cache and chunks is None: data = indexing.MemoryCachedArray(data) - result = DataArray( - data=data, - dims=("band", "y", "x"), - coords=coords, - attrs=attrs, - ) + result = DataArray(data=data, dims=("band", "y", "x"), coords=coords, attrs=attrs) result.set_close(manager.close) if chunks is not None: from dask.base import tokenize From 3a67acac56ca4d6d6b33f126958c94594853822b Mon Sep 17 00:00:00 2001 From: Alessandro Amici Date: Fri, 15 Jan 2021 09:28:54 +0100 Subject: [PATCH 10/17] Revert changes in _replace to keep close as an exception See: https://github.com/pydata/xarray/pull/4809/files#r557628298 --- xarray/backends/api.py | 3 +++ xarray/backends/apiv2.py | 2 ++ xarray/backends/rasterio_.py | 4 +++- xarray/backends/store.py | 2 +- xarray/conventions.py | 2 +- xarray/core/dataarray.py | 17 +++-------------- xarray/core/dataset.py | 7 +------ 7 files changed, 14 insertions(+), 23 deletions(-) diff --git a/xarray/backends/api.py b/xarray/backends/api.py index 2b5c3d255a4..4698664ea37 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -522,6 +522,7 @@ def maybe_decode_store(store, chunks): else: ds2 = ds + ds2.set_close(ds._close) return ds2 filename_or_obj = _normalize_path(filename_or_obj) @@ -700,6 +701,8 @@ def open_dataarray( else: (data_array,) = dataset.data_vars.values() + data_array.set_close(dataset._close) + # Reset names if they were changed during saving # to ensure that we can 'roundtrip' perfectly if DATAARRAY_NAME in dataset.attrs: diff --git a/xarray/backends/apiv2.py b/xarray/backends/apiv2.py index 1d649896e64..d31fc9ea773 100644 --- a/xarray/backends/apiv2.py +++ b/xarray/backends/apiv2.py @@ -90,6 +90,8 @@ def _dataset_from_backend_dataset( **extra_tokens, ) + ds.set_close(backend_ds._close) + # Ensure source filename always stored in dataset object (GH issue #2550) if "source" not in ds.encoding: if isinstance(filename_or_obj, str): diff --git a/xarray/backends/rasterio_.py b/xarray/backends/rasterio_.py index f96a892521a..c689c1e99d7 100644 ---
a/xarray/backends/rasterio_.py +++ b/xarray/backends/rasterio_.py @@ -346,7 +346,6 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None, loc data = indexing.MemoryCachedArray(data) result = DataArray(data=data, dims=("band", "y", "x"), coords=coords, attrs=attrs) - result.set_close(manager.close) if chunks is not None: from dask.base import tokenize @@ -361,4 +360,7 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None, loc name_prefix = "open_rasterio-%s" % token result = result.chunk(chunks, name_prefix=name_prefix, token=token) + # Make the file closeable + result.set_close(manager.close) + return result diff --git a/xarray/backends/store.py b/xarray/backends/store.py index 802c78d241d..20fa13af202 100644 --- a/xarray/backends/store.py +++ b/xarray/backends/store.py @@ -34,8 +34,8 @@ def open_backend_dataset_store( ) ds = Dataset(vars, attrs=attrs) - ds.set_close(store.close) ds = ds.set_coords(coord_names.intersection(vars)) + ds.set_close(store.close) ds.encoding = encoding return ds diff --git a/xarray/conventions.py b/xarray/conventions.py index 76f242c3f13..e33ae53b31d 100644 --- a/xarray/conventions.py +++ b/xarray/conventions.py @@ -598,8 +598,8 @@ def decode_cf( decode_timedelta=decode_timedelta, ) ds = Dataset(vars, attrs=attrs) - ds.set_close(close) ds = ds.set_coords(coord_names.union(extra_coords).intersection(vars)) + ds.set_close(close) ds.encoding = encoding return ds diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 6183709d337..e13ea44baad 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -430,7 +430,6 @@ def _replace( coords=None, name: Union[Hashable, None, Default] = _default, indexes=None, - close: Union[Callable[[], None], None, Default] = _default, ) -> "DataArray": if variable is None: variable = self.variable @@ -438,13 +437,7 @@ def _replace( coords = self._coords if name is _default: name = self.name - if close is _default: - close = self._close - replaced = type(self)( - variable, coords, name=name, fastpath=True, indexes=indexes - ) - replaced.set_close(close) - return replaced + return type(self)(variable, coords, name=name, fastpath=True, indexes=indexes) def _replace_maybe_drop_dims( self, variable: Variable, name: Union[Hashable, None, Default] = _default @@ -500,8 +493,7 @@ def _from_temp_dataset( variable = dataset._variables.pop(_THIS_ARRAY) coords = dataset._variables indexes = dataset._indexes - close = dataset._close - return self._replace(variable, coords, name, indexes=indexes, close=close) + return self._replace(variable, coords, name, indexes=indexes) def _to_dataset_split(self, dim: Hashable) -> Dataset: """ splits dataarray along dimension 'dim' """ @@ -545,10 +537,7 @@ def _to_dataset_whole( indexes = self._indexes coord_names = set(self._coords) - close = self._close - dataset = Dataset._construct_direct( - variables, coord_names, indexes=indexes, close=close - ) + dataset = Dataset._construct_direct(variables, coord_names, indexes=indexes) return dataset def to_dataset( diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index cca16dfdfd4..e2dcf580600 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -1034,7 +1034,6 @@ def _replace( indexes: Union[Dict[Any, pd.Index], None, Default] = _default, encoding: Union[dict, None, Default] = _default, inplace: bool = False, - close: Union[Callable[[], None], None, Default] = _default, ) -> "Dataset": """Fastpath constructor for internal use. 
@@ -1057,8 +1056,6 @@ def _replace( self._indexes = indexes if encoding is not _default: self._encoding = encoding - if close is not _default: - self._close = close obj = self else: if variables is None: @@ -1073,10 +1070,8 @@ def _replace( indexes = copy.copy(self._indexes) if encoding is _default: encoding = copy.copy(self._encoding) - if close is _default: - close = self._close obj = self._construct_direct( - variables, coord_names, dims, attrs, indexes, encoding, close + variables, coord_names, dims, attrs, indexes, encoding ) return obj From 32976b157a7f1af5061eeac4fa218eada20d3927 Mon Sep 17 00:00:00 2001 From: Alessandro Amici Date: Fri, 15 Jan 2021 09:32:04 +0100 Subject: [PATCH 11/17] One more bit to revert --- xarray/core/dataset.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index e2dcf580600..8a6966319cd 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -1331,9 +1331,7 @@ def _construct_dataarray(self, name: Hashable) -> "DataArray": else: indexes = {k: v for k, v in self._indexes.items() if k in coords} - da = DataArray(variable, coords, name=name, indexes=indexes, fastpath=True) - da.set_close(self._close) - return da + return DataArray(variable, coords, name=name, indexes=indexes, fastpath=True) def __copy__(self) -> "Dataset": return self.copy(deep=False) From 82efe0d8b221b2350cd9f2d32e5ed04717728f9f Mon Sep 17 00:00:00 2001 From: Alessandro Amici Date: Fri, 15 Jan 2021 09:34:20 +0100 Subject: [PATCH 12/17] One more bit to revert --- xarray/core/dataset.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 8a6966319cd..4382c254974 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -4787,9 +4787,9 @@ def to_array(self, dim="variable", name=None): dims = (dim,) + broadcast_vars[0].dims - da = DataArray(data, coords, dims, attrs=self.attrs, name=name, indexes=indexes) - da.set_close(self._close) - return da + return DataArray( + data, coords, dims, attrs=self.attrs, name=name, indexes=indexes ) def _normalize_dim_order( self, dim_order: List[Hashable] = None From 675b0fe53dbf7d41d461918d063f9fe686c03371 Mon Sep 17 00:00:00 2001 From: Alessandro Amici Date: Fri, 15 Jan 2021 10:31:23 +0100 Subject: [PATCH 13/17] Add What's New entry --- doc/whats-new.rst | 2 ++ 1 file changed, 2 insertions(+) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index db10ec653c5..5465cb7761e 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -103,6 +103,8 @@ Internal Changes By `Maximilian Roos `_. - Speed up attribute style access (e.g. ``ds.somevar`` instead of ``ds["somevar"]``) and tab completion in ipython (:issue:`4741`, :pull:`4742`). By `Richard Kleijn `_. +- Added the ``set_close`` method to ``Dataset`` and ``DataArray`` for backends to specify how to voluntarily release + all resources (:pull:`4809`). By `Alessandro Amici `_. ..
_whats-new.0.16.2: From 538371fe09d78b4710ee58d8950832d842eecbb5 Mon Sep 17 00:00:00 2001 From: Alessandro Amici Date: Fri, 15 Jan 2021 19:57:12 +0100 Subject: [PATCH 14/17] Use set_close setter --- xarray/backends/api.py | 2 +- xarray/core/dataset.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/xarray/backends/api.py b/xarray/backends/api.py index 4698664ea37..30f9532b29a 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -956,7 +956,7 @@ def multi_file_closer(): for closer in closers: closer() - combined._close = multi_file_closer + combined.set_close(multi_file_closer) # read global attributes from the attrs_file or from the first dataset if attrs_file is not None: diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 4382c254974..136edffb202 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -704,7 +704,7 @@ def load_store(cls, store, decoder=None) -> "Dataset": if decoder: variables, attributes = decoder(variables, attributes) obj = cls(variables, attrs=attributes) - obj._close = store.close + obj.set_close(store.close) return obj @property From 65e43bfa28e08cec34da7072732609db6d086ece Mon Sep 17 00:00:00 2001 From: alexamici Date: Fri, 15 Jan 2021 20:04:33 +0100 Subject: [PATCH 15/17] Apply suggestions from code review Co-authored-by: Stephan Hoyer --- xarray/core/common.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/xarray/core/common.py b/xarray/core/common.py index e84b44d7441..e2b41d6c0e0 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -1269,13 +1269,16 @@ def where(self, cond, other=dtypes.NA, drop: bool = False): def set_close(self, close_store: Optional[Callable[[], None]]) -> None: """Register the function that releases all resources used by the data store. - This method is mostly intended for backend developers and it is rarely - needed by regular end-users. + This method controls how xarray cleans up resources associated + with this object when the .close() method is called. It is mostly + intended for backend developers and it is rarely needed by regular + end-users. Parameters ---------- close_store : callable - A callable that releases all resources used by the data store. + A callable that when called like ``close_store()`` releases + all resources used by the data store. """ self._close = close_store From a8f96b69e60259b9d4bef11f97459dea0226e376 Mon Sep 17 00:00:00 2001 From: Alessandro Amici Date: Fri, 15 Jan 2021 20:05:47 +0100 Subject: [PATCH 16/17] Rename user-visible argument --- xarray/core/common.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/xarray/core/common.py b/xarray/core/common.py index e2b41d6c0e0..402ae545e5f 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -1266,7 +1266,7 @@ def where(self, cond, other=dtypes.NA, drop: bool = False): return ops.where_method(self, cond, other) - def set_close(self, close_store: Optional[Callable[[], None]]) -> None: + def set_close(self, close: Optional[Callable[[], None]]) -> None: """Register the function that releases all resources used by the data store. This method controls how xarray cleans up resources associated @@ -1276,11 +1276,11 @@ def set_close(self, close_store: Optional[Callable[[], None]]) -> None: Parameters ---------- - close_store : callable - A callable that when called like ``close_store()`` releases + close : callable + A callable that when called like ``close()`` releases all resources used by the data store. 
""" - self._close = close_store + self._close = close def close(self: Any) -> None: """Close any files linked to this object""" From 5f047a6de9535317ee982dda60ff0c859bb48435 Mon Sep 17 00:00:00 2001 From: Alessandro Amici Date: Fri, 15 Jan 2021 20:43:57 +0100 Subject: [PATCH 17/17] Sync wording in docstrings. --- xarray/core/common.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/xarray/core/common.py b/xarray/core/common.py index 402ae545e5f..a69ba03a7a4 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -1267,23 +1267,23 @@ def where(self, cond, other=dtypes.NA, drop: bool = False): return ops.where_method(self, cond, other) def set_close(self, close: Optional[Callable[[], None]]) -> None: - """Register the function that releases all resources used by the data store. + """Register the function that releases any resources linked to this object. This method controls how xarray cleans up resources associated - with this object when the .close() method is called. It is mostly + with this object when the ``.close()`` method is called. It is mostly intended for backend developers and it is rarely needed by regular end-users. Parameters ---------- close : callable - A callable that when called like ``close()`` releases - all resources used by the data store. + The function that when called like ``close()`` releases + any resources linked to this object. """ self._close = close def close(self: Any) -> None: - """Close any files linked to this object""" + """Release any resources linked to this object.""" if self._close is not None: self._close() self._close = None