From 81f9b94a9c4747640b3c82132a670dd4f73741ad Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Tue, 15 Nov 2016 15:17:54 -0800 Subject: [PATCH 01/14] Disable all caching on xarray.Variable This is a follow-up to generalize the changes from #1024: - Caching and copy-on-write behavior has been moved to separate array classes that are explicitly used in `open_dataset` to wrap arrays loaded from disk (if `cache=True`). - Dask specific logic has been removed from the caching/loading logic on `xarray.Variable`. - Pickle no longer caches automatically under any circumstances. Still needs tests for the `cache` argument to `open_dataset`, but everything else seems to be working. --- doc/whats-new.rst | 14 ++++--- xarray/backends/api.py | 32 +++++++++++++-- xarray/backends/h5netcdf_.py | 5 ++- xarray/backends/netCDF4_.py | 6 +-- xarray/backends/pynio_.py | 6 +-- xarray/backends/scipy_.py | 6 +-- xarray/core/common.py | 8 ++++ xarray/core/dataset.py | 21 +++++----- xarray/core/indexing.py | 40 +++++++++++++++++++ xarray/core/utils.py | 7 ++++ xarray/core/variable.py | 63 ++++++++++-------------------- xarray/test/test_backends.py | 69 +++++++++++++++++++++++---------- xarray/test/test_conventions.py | 8 ++-- xarray/test/test_indexing.py | 42 ++++++++++++++++++++ xarray/test/test_utils.py | 9 +++++ 15 files changed, 239 insertions(+), 97 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 31048f333ab..9067235abc1 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -25,12 +25,14 @@ Breaking changes merges will now succeed in cases that previously raised ``xarray.MergeError``. Set ``compat='broadcast_equals'`` to restore the previous default. -- Pickling an xarray object based on the dask backend, or reading its - :py:meth:`values` property, won't automatically convert the array from dask - to numpy in the original object anymore. - If a dask object is used as a coord of a :py:class:`~xarray.DataArray` or - :py:class:`~xarray.Dataset`, its values are eagerly computed and cached, - but only if it's used to index a dim (e.g. it's used for alignment). +- Pickling an xarray object or reading its :py:attr:`~DataArray.values` + property no longer always caches values in a NumPy array. Caching + of ``.values`` read from netCDF files on disk is still the default when + :py:func:`open_dataset` is called with ``cache=True``. + By `Guido Imperiale `_ and + `Stephan Hoyer `_. +- Coordinates used to index a dimension are now loaded eagerly into + :py:class:`pandas.Index` objects, instead of loading the values lazily. By `Guido Imperiale `_. Deprecations diff --git a/xarray/backends/api.py b/xarray/backends/api.py index ae12b0cc74a..8ae541ede63 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -13,13 +13,16 @@ from .. import backends, conventions from .common import ArrayWriter +from ..core import indexing from ..core.combine import auto_combine from ..core.utils import close_on_error, is_remote_uri +from ..core.variable import Variable, IndexVariable from ..core.pycompat import basestring DATAARRAY_NAME = '__xarray_dataarray_name__' DATAARRAY_VARIABLE = '__xarray_dataarray_variable__' + def _get_default_engine(path, allow_remote=False): if allow_remote and is_remote_uri(path): # pragma: no cover try: @@ -117,10 +120,20 @@ def check_attr(name, value): check_attr(k, v) +def _protect_dataset_variables_inplace(dataset, cache): + for name, variable in dataset.variables.items(): + if name not in variable.dims: + # no need to protect IndexVariable objects + data = indexing.CopyOnWriteArray(variable._data) + if cache: + data = indexing.MemoryCachedArray(data) + variable._data = data + + def open_dataset(filename_or_obj, group=None, decode_cf=True, mask_and_scale=True, decode_times=True, concat_characters=True, decode_coords=True, engine=None, - chunks=None, lock=None, drop_variables=None): + chunks=None, lock=None, cache=None, drop_variables=None): """Load and decode a dataset from a file or file-like object. Parameters @@ -162,14 +175,22 @@ def open_dataset(filename_or_obj, group=None, decode_cf=True, 'netcdf4'. chunks : int or dict, optional If chunks is provided, it used to load the new dataset into dask - arrays. This is an experimental feature; see the documentation for more - details. + arrays. ``chunks={}`` loads the dataset with dask using a single + chunk for all arrays. This is an experimental feature; see the + documentation for more details. lock : False, True or threading.Lock, optional If chunks is provided, this argument is passed on to :py:func:`dask.array.from_array`. By default, a per-variable lock is used when reading data from netCDF files with the netcdf4 and h5netcdf engines to avoid issues with concurrent access when using dask's multithreaded backend. + cache : bool, optional + If True, cache data loaded from the underlying datastore in memory as + NumPy arrays when accessed to avoid reading from the underlying data- + store multiple times. Defaults to True unless you specify the `chunks` + argument to use dask, in which case it defaults to False. Does not + change the behavior of coordinates corresponding to dimensions, which + always load their data from disk into a ``pandas.Index``. drop_variables: string or iterable, optional A variable or list of variables to exclude from being parsed from the dataset. This may be useful to drop variables with problems or @@ -190,12 +211,17 @@ def open_dataset(filename_or_obj, group=None, decode_cf=True, concat_characters = False decode_coords = False + if cache is None: + cache = chunks is not None + def maybe_decode_store(store, lock=False): ds = conventions.decode_cf( store, mask_and_scale=mask_and_scale, decode_times=decode_times, concat_characters=concat_characters, decode_coords=decode_coords, drop_variables=drop_variables) + _protect_dataset_variables_inplace(ds, cache) + if chunks is not None: try: from dask.base import tokenize diff --git a/xarray/backends/h5netcdf_.py b/xarray/backends/h5netcdf_.py index 8796e276994..38eee9c23d6 100644 --- a/xarray/backends/h5netcdf_.py +++ b/xarray/backends/h5netcdf_.py @@ -5,7 +5,8 @@ from .. import Variable from ..core import indexing -from ..core.utils import FrozenOrderedDict, close_on_error, Frozen +from ..core.utils import (FrozenOrderedDict, close_on_error, Frozen, + NoPickleMixin) from ..core.pycompat import iteritems, bytes_type, unicode_type, OrderedDict from .common import WritableCFDataStore @@ -37,7 +38,7 @@ def _read_attributes(h5netcdf_var): lsd_okay=False, backend='h5netcdf') -class H5NetCDFStore(WritableCFDataStore): +class H5NetCDFStore(WritableCFDataStore, NoPickleMixin): """Store for reading and writing data via h5netcdf """ def __init__(self, filename, mode='r', format=None, group=None, diff --git a/xarray/backends/netCDF4_.py b/xarray/backends/netCDF4_.py index b0acb31ed45..70760030298 100644 --- a/xarray/backends/netCDF4_.py +++ b/xarray/backends/netCDF4_.py @@ -9,7 +9,7 @@ from .. import Variable from ..conventions import pop_to, cf_encoder from ..core import indexing -from ..core.utils import (FrozenOrderedDict, NDArrayMixin, +from ..core.utils import (FrozenOrderedDict, NDArrayMixin, NoPickleMixin, close_on_error, is_remote_uri) from ..core.pycompat import iteritems, basestring, OrderedDict, PY3 @@ -25,7 +25,7 @@ '|': 'native'} -class BaseNetCDF4Array(NDArrayMixin): +class BaseNetCDF4Array(NDArrayMixin, NoPickleMixin): def __init__(self, array, is_remote=False): self.array = array self.is_remote = is_remote @@ -176,7 +176,7 @@ def _extract_nc4_encoding(variable, raise_on_invalid=False, lsd_okay=True, return encoding -class NetCDF4DataStore(WritableCFDataStore): +class NetCDF4DataStore(WritableCFDataStore, NoPickleMixin): """Store for reading and writing data via the Python-NetCDF4 library. This store supports NetCDF3, NetCDF4 and OpenDAP datasets. diff --git a/xarray/backends/pynio_.py b/xarray/backends/pynio_.py index 7ea7f21b651..c545065df93 100644 --- a/xarray/backends/pynio_.py +++ b/xarray/backends/pynio_.py @@ -4,13 +4,13 @@ import numpy as np from .. import Variable -from ..core.utils import FrozenOrderedDict, Frozen, NDArrayMixin +from ..core.utils import FrozenOrderedDict, Frozen, NDArrayMixin, NoPickleMixin from ..core import indexing from .common import AbstractDataStore -class NioArrayWrapper(NDArrayMixin): +class NioArrayWrapper(NDArrayMixin, NoPickleMixin): def __init__(self, array, ds): self.array = array self._ds = ds # make an explicit reference because pynio uses weakrefs @@ -25,7 +25,7 @@ def __getitem__(self, key): return self.array[key] -class NioDataStore(AbstractDataStore): +class NioDataStore(AbstractDataStore, NoPickleMixin): """Store for accessing datasets via PyNIO """ def __init__(self, filename, mode='r'): diff --git a/xarray/backends/scipy_.py b/xarray/backends/scipy_.py index 200834d2f2c..4aba4366a5c 100644 --- a/xarray/backends/scipy_.py +++ b/xarray/backends/scipy_.py @@ -8,7 +8,7 @@ from .. import Variable from ..core.pycompat import iteritems, basestring, OrderedDict -from ..core.utils import Frozen, FrozenOrderedDict +from ..core.utils import Frozen, FrozenOrderedDict, NoPickleMixin from ..core.indexing import NumpyIndexingAdapter from .common import WritableCFDataStore @@ -29,7 +29,7 @@ def _decode_attrs(d): for (k, v) in iteritems(d)) -class ScipyArrayWrapper(NumpyIndexingAdapter): +class ScipyArrayWrapper(NumpyIndexingAdapter, NoPickleMixin): def __init__(self, netcdf_file, variable_name): self.netcdf_file = netcdf_file self.variable_name = variable_name @@ -57,7 +57,7 @@ def __getitem__(self, key): return data -class ScipyDataStore(WritableCFDataStore): +class ScipyDataStore(WritableCFDataStore, NoPickleMixin): """Store for reading and writing data via scipy.io.netcdf. This store has the advantage of being able to be initialized with a diff --git a/xarray/core/common.py b/xarray/core/common.py index 5ac9994ee8c..7afa20a6864 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -248,6 +248,14 @@ def __dir__(self): if isinstance(item, basestring)] return sorted(set(dir(type(self)) + extra_attrs)) + def __getstate__(self): + """Get this object's state for pickling""" + # we need a custom method to avoid + + # self.__dict__ is the default pickle object, we don't need to + # implement our own __setstate__ method to make pickle work + return self.__dict__ + class SharedMethodsMixin(object): """Shared methods for Dataset, DataArray and Variable.""" diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 282409cf4ca..3ea04495829 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -260,17 +260,12 @@ def load_store(cls, store, decoder=None): return obj def __getstate__(self): - """Load data in-memory before pickling (except for Dask data)""" - for v in self.variables.values(): - if not isinstance(v.data, dask_array_type): - v.load() + """Get this object's state for pickling""" + # we need a custom method to avoid # self.__dict__ is the default pickle object, we don't need to # implement our own __setstate__ method to make pickle work - state = self.__dict__.copy() - # throw away any references to datastores in the pickle - state['_file_obj'] = None - return state + return self.__dict__ @property def variables(self): @@ -331,9 +326,8 @@ def load(self): working with many file objects on disk. """ # access .data to coerce everything to numpy or dask arrays - all_data = dict((k, v.data) for k, v in self.variables.items()) - lazy_data = dict((k, v) for k, v in all_data.items() - if isinstance(v, dask_array_type)) + lazy_data = {k: v._data for k, v in self.variables.items() + if isinstance(v._data, dask_array_type)} if lazy_data: import dask.array as da @@ -343,6 +337,11 @@ def load(self): for k, data in zip(lazy_data, evaluated_data): self.variables[k].data = data + # load everything else sequentially + for k, v in self.variables.items(): + if k not in lazy_data: + v.load() + return self def compute(self): diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index 8cbb91ebd4f..6c4d47ade10 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -403,6 +403,46 @@ def __repr__(self): (type(self).__name__, self.array, self.key)) +class CopyOnWriteArray(utils.NDArrayMixin): + def __init__(self, array): + self.array = array + self._copied = False + + def _ensure_copied(self): + if not self._copied: + self.array = np.array(self.array) + self._copied = True + + def __array__(self, dtype=None): + return np.asarray(self.array, dtype=dtype) + + def __getitem__(self, key): + return type(self)(self.array[key]) + + def __setitem__(self, key, value): + self._ensure_copied() + self.array[key] = value + + +class MemoryCachedArray(utils.NDArrayMixin): + def __init__(self, array): + self.array = array + + def _ensure_cached(self): + if not isinstance(self.array, np.ndarray): + self.array = np.asarray(self.array) + + def __array__(self, dtype=None): + self._ensure_cached() + return np.asarray(self.array, dtype=dtype) + + def __getitem__(self, key): + return type(self)(self.array[key]) + + def __setitem__(self, key, value): + self.array[key] = value + + def orthogonally_indexable(array): if isinstance(array, np.ndarray): return NumpyIndexingAdapter(array) diff --git a/xarray/core/utils.py b/xarray/core/utils.py index ff7ccbc2670..5a701680f4d 100644 --- a/xarray/core/utils.py +++ b/xarray/core/utils.py @@ -420,6 +420,13 @@ def __repr__(self): return '%s(array=%r)' % (type(self).__name__, self.array) +class NoPickleMixin(object): + def __getstate__(self): + raise TypeError( + 'cannot pickle objects of type %r: call .compute() or .load() ' + 'to load data into memory first.' % type(self)) + + @contextlib.contextmanager def close_on_error(f): """Context manager to ensure that a file opened by xarray is closed if an diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 1b6f5b55dda..47746e07cad 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -142,10 +142,12 @@ def as_compatible_data(data, fastpath=False): if isinstance(data, timedelta): data = np.timedelta64(getattr(data, 'value', data), 'ns') - if (not hasattr(data, 'dtype') or not hasattr(data, 'shape') or + if (not hasattr(data, 'dtype') or not isinstance(data.dtype, np.dtype) or + not hasattr(data, 'shape') or isinstance(data, (np.string_, np.unicode_, np.datetime64, np.timedelta64))): # data must be ndarray-like + # don't allow non-numpy dtypes (e.g., categories) data = np.asarray(data) # we don't want nested self-described arrays @@ -260,7 +262,9 @@ def nbytes(self): @property def _in_memory(self): - return isinstance(self._data, (np.ndarray, PandasIndexAdapter)) + return (isinstance(self._data, (np.ndarray, PandasIndexAdapter)) or + (isinstance(self._data, indexing.MemoryCachedArray) and + isinstance(self._data.array, np.ndarray))) @property def data(self): @@ -277,22 +281,6 @@ def data(self, data): "replacement data must match the Variable's shape") self._data = data - def _data_cast(self): - if isinstance(self._data, (np.ndarray, PandasIndexAdapter)): - return self._data - else: - return np.asarray(self._data) - - def _data_cached(self): - """Load data into memory and return it. - Do not cache dask arrays automatically; that should - require an explicit load() call. - """ - new_data = self._data_cast() - if not isinstance(self._data, dask_array_type): - self._data = new_data - return new_data - @property def _indexable_data(self): return orthogonally_indexable(self._data) @@ -305,7 +293,8 @@ def load(self): because all xarray functions should either work on deferred data or load data automatically. """ - self._data = self._data_cast() + if not isinstance(self._data, np.ndarray): + self._data = np.asarray(self._data) return self def compute(self): @@ -320,19 +309,10 @@ def compute(self): new = self.copy(deep=False) return new.load() - def __getstate__(self): - """Always cache data as an in-memory array before pickling - (with the exception of dask backend)""" - if not isinstance(self._data, dask_array_type): - self._data_cached() - # self.__dict__ is the default pickle object, we don't need to - # implement our own __setstate__ method to make pickle work - return self.__dict__ - @property def values(self): """The variable's data as a numpy.ndarray""" - return _as_array_or_item(self._data_cached()) + return _as_array_or_item(self._data) @values.setter def values(self, values): @@ -425,7 +405,7 @@ def __setitem__(self, key, value): 'assign to this variable, you must first load it ' 'into memory explicitly using the .load_data() ' 'method or accessing its .values attribute.') - data = orthogonally_indexable(self._data_cached()) + data = orthogonally_indexable(self._data) data[key] = value @property @@ -462,7 +442,7 @@ def copy(self, deep=True): the new object. Dimensions, attributes and encodings are always copied. """ if (deep and not isinstance(self.data, dask_array_type) - and not isinstance(self._data, PandasIndexAdapter)): + and not isinstance(self._data, PandasIndexAdapter)): # pandas.Index objects are immutable # dask arrays don't have a copy method # https://github.com/blaze/dask/issues/911 @@ -1144,15 +1124,16 @@ def concat(cls, variables, dim='concat_dim', positions=None, raise TypeError('IndexVariable.concat requires that all input ' 'variables be IndexVariable objects') - arrays = [v._data_cached().array for v in variables] + indexes = [v._data.array for v in variables] - if not arrays: + if not indexes: data = [] else: - data = arrays[0].append(arrays[1:]) + data = indexes[0].append(indexes[1:]) if positions is not None: - indices = nputils.inverse_permutation(np.concatenate(positions)) + indices = nputils.inverse_permutation( + np.concatenate(positions)) data = data.take(indices) attrs = OrderedDict(first_var.attrs) @@ -1167,13 +1148,11 @@ def concat(cls, variables, dim='concat_dim', positions=None, def copy(self, deep=True): """Returns a copy of this object. - If `deep=True`, the values array is loaded into memory and copied onto - the new object. Dimensions, attributes and encodings are always copied. + `deep` is ignored since data is stored in the form of pandas.Index, + which is already immutable. Dimensions, attributes and encodings are + always copied. """ - # there is no need to copy the index values here even if deep=True - # since pandas.Index objects are immutable - data = PandasIndexAdapter(self) if deep else self._data - return type(self)(self.dims, data, self._attrs, + return type(self)(self.dims, self._data, self._attrs, self._encoding, fastpath=True) def _data_equals(self, other): @@ -1190,7 +1169,7 @@ def to_index(self): # n.b. creating a new pandas.Index from an old pandas.Index is # basically free as pandas.Index objects are immutable assert self.ndim == 1 - index = self._data_cached().array + index = self._data.array if isinstance(index, pd.MultiIndex): # set default names for multi-index unnamed levels so that # we can safely rename dimension / coordinate later diff --git a/xarray/test/test_backends.py b/xarray/test/test_backends.py index 65e0f3d51ac..c9200907c7e 100644 --- a/xarray/test/test_backends.py +++ b/xarray/test/test_backends.py @@ -163,18 +163,12 @@ def test_dataset_compute(self): # Test Dataset.compute() for k, v in actual.variables.items(): # IndexVariables are eagerly cached - if k in actual.dims: - self.assertTrue(v._in_memory) - else: - self.assertFalse(v._in_memory) + self.assertEqual(v._in_memory, k in actual.dims) computed = actual.compute() for k, v in actual.variables.items(): - if k in actual.dims: - self.assertTrue(v._in_memory) - else: - self.assertFalse(v._in_memory) + self.assertEqual(v._in_memory, k in actual.dims) for v in computed.variables.values(): self.assertTrue(v._in_memory) @@ -282,10 +276,40 @@ def test_orthogonal_indexing(self): actual = on_disk.isel(**indexers) self.assertDatasetAllClose(expected, actual) + +class PickleNotSupportedMixin(object): + + def test_pickle(self): + expected = Dataset({'foo': ('x', [42])}) + with self.roundtrip(expected) as on_disk: + with self.assertRaisesRegexp( + TypeError, 'load data into memory first'): + pickle.dumps(on_disk) + computed_ds = on_disk.compute() + unpickled_ds = pickle.loads(pickle.dumps(computed_ds)) + self.assertDatasetIdentical(expected, computed_ds) + self.assertDatasetIdentical(expected, unpickled_ds) + + with self.assertRaisesRegexp( + TypeError, 'load data into memory first'): + pickle.dumps(on_disk['foo']) + computed_array = on_disk['foo'].compute() + unpickled_array = pickle.loads(pickle.dumps(computed_array)) + self.assertDatasetIdentical(expected['foo'], computed_array) + self.assertDatasetIdentical(expected['foo'], unpickled_array) + + +class PickleSupportedMixin(object): + def test_pickle(self): - on_disk = open_example_dataset('bears.nc') - unpickled = pickle.loads(pickle.dumps(on_disk)) - self.assertDatasetIdentical(on_disk, unpickled) + # this should work for dask arrays, unlike most real data stores + expected = Dataset({'foo': ('x', [42])}) + with self.roundtrip(expected) as roundtripped: + unpickled_ds = pickle.loads(pickle.dumps(roundtripped)) + self.assertDatasetIdentical(expected, unpickled_ds) + + unpickled_array = pickle.loads(pickle.dumps(roundtripped['foo'])) + self.assertDatasetIdentical(expected['foo'], unpickled_array) class CFEncodedDataTest(DatasetIOTestCases): @@ -443,7 +467,7 @@ def create_tmp_file(suffix='.nc'): shutil.rmtree(temp_dir) -class BaseNetCDF4Test(CFEncodedDataTest): +class BaseNetCDF4Test(CFEncodedDataTest, PickleNotSupportedMixin): def test_open_group(self): # Create a netCDF file with a dataset stored within a group with create_tmp_file() as tmp_file: @@ -642,7 +666,7 @@ def test_variable_len_strings(self): @requires_netCDF4 -class NetCDF4DataTest(BaseNetCDF4Test, TestCase): +class NetCDF4DataTest(BaseNetCDF4Test, PickleNotSupportedMixin, TestCase): @contextlib.contextmanager def create_store(self): with create_tmp_file() as tmp_file: @@ -687,7 +711,7 @@ def test_unsorted_index_raises(self): @requires_netCDF4 @requires_dask -class NetCDF4ViaDaskDataTest(NetCDF4DataTest): +class NetCDF4ViaDaskDataTest(NetCDF4DataTest, PickleNotSupportedMixin): @contextlib.contextmanager def roundtrip(self, data, save_kwargs={}, open_kwargs={}): with NetCDF4DataTest.roundtrip( @@ -753,7 +777,8 @@ def test_netcdf3_endianness(self): @requires_netCDF4 -class NetCDF3ViaNetCDF4DataTest(CFEncodedDataTest, Only32BitTypes, TestCase): +class NetCDF3ViaNetCDF4DataTest(CFEncodedDataTest, Only32BitTypes, + PickleNotSupportedMixin, TestCase): @contextlib.contextmanager def create_store(self): with create_tmp_file() as tmp_file: @@ -771,7 +796,8 @@ def roundtrip(self, data, save_kwargs={}, open_kwargs={}): @requires_netCDF4 -class NetCDF4ClassicViaNetCDF4DataTest(CFEncodedDataTest, Only32BitTypes, TestCase): +class NetCDF4ClassicViaNetCDF4DataTest(CFEncodedDataTest, Only32BitTypes, + PickleNotSupportedMixin, TestCase): @contextlib.contextmanager def create_store(self): with create_tmp_file() as tmp_file: @@ -789,7 +815,8 @@ def roundtrip(self, data, save_kwargs={}, open_kwargs={}): @requires_scipy_or_netCDF4 -class GenericNetCDFDataTest(CFEncodedDataTest, Only32BitTypes, TestCase): +class GenericNetCDFDataTest(CFEncodedDataTest, Only32BitTypes, + PickleNotSupportedMixin, TestCase): # verify that we can read and write netCDF3 files as long as we have scipy # or netCDF4-python installed @@ -841,7 +868,7 @@ def test_cross_engine_read_write_netcdf3(self): @requires_h5netcdf @requires_netCDF4 -class H5NetCDFDataTest(BaseNetCDF4Test, TestCase): +class H5NetCDFDataTest(BaseNetCDF4Test, PickleNotSupportedMixin, TestCase): @contextlib.contextmanager def create_store(self): with create_tmp_file() as tmp_file: @@ -892,7 +919,7 @@ def test_read_byte_attrs_as_unicode(self): @requires_dask @requires_scipy @requires_netCDF4 -class DaskTest(TestCase, DatasetIOTestCases): +class DaskTest(TestCase, PickleSupportedMixin, DatasetIOTestCases): @contextlib.contextmanager def create_store(self): yield Dataset() @@ -1041,6 +1068,7 @@ def test_dataarray_compute(self): self.assertTrue(computed._in_memory) self.assertDataArrayAllClose(actual, computed) + @requires_scipy_or_netCDF4 @requires_pydap class PydapTest(TestCase): @@ -1083,7 +1111,8 @@ def test_dask(self): @requires_scipy @requires_pynio -class TestPyNio(CFEncodedDataTest, Only32BitTypes, TestCase): +class TestPyNio(CFEncodedDataTest, PickleNotSupportedMixin, Only32BitTypes, + TestCase): def test_write_store(self): # pynio is read-only for now pass diff --git a/xarray/test/test_conventions.py b/xarray/test/test_conventions.py index 265b2a93b0f..e3c2c2bd0ff 100644 --- a/xarray/test/test_conventions.py +++ b/xarray/test/test_conventions.py @@ -9,7 +9,7 @@ from xarray import conventions, Variable, Dataset, open_dataset from xarray.core import utils, indexing from . import TestCase, requires_netCDF4, unittest -from .test_backends import CFEncodedDataTest +from .test_backends import CFEncodedDataTest, PickleSupportedMixin from xarray.core.pycompat import iteritems from xarray.backends.memory import InMemoryDataStore from xarray.backends.common import WritableCFDataStore @@ -191,11 +191,11 @@ def test_cf_datetime(self): @requires_netCDF4 def test_decode_cf_datetime_overflow(self): - # checks for + # checks for # https://github.com/pydata/pandas/issues/14068 # https://github.com/pydata/xarray/issues/975 - from datetime import datetime + from datetime import datetime units = 'days since 2000-01-01 00:00:00' # date after 2262 and before 1678 @@ -620,7 +620,7 @@ def null_wrap(ds): @requires_netCDF4 -class TestCFEncodedDataStore(CFEncodedDataTest, TestCase): +class TestCFEncodedDataStore(CFEncodedDataTest, PickleSupportedMixin, TestCase): @contextlib.contextmanager def create_store(self): yield CFEncodedInMemoryStore() diff --git a/xarray/test/test_indexing.py b/xarray/test/test_indexing.py index ed61c9eb463..9d22b4f2c87 100644 --- a/xarray/test/test_indexing.py +++ b/xarray/test/test_indexing.py @@ -200,3 +200,45 @@ def test_lazily_indexed_array(self): actual = lazy[i][j] self.assertEqual(expected.shape, actual.shape) self.assertArrayEqual(expected, actual) + + +class TestCopyOnWriteArray(TestCase): + def test_setitem(self): + original = np.arange(10) + wrapped = indexing.CopyOnWriteArray(original) + wrapped[:] = 0 + self.assertArrayEqual(original, np.arange(10)) + self.assertArrayEqual(wrapped, np.zeros(10)) + + def test_sub_array(self): + original = np.arange(10) + wrapped = indexing.CopyOnWriteArray(original) + child = wrapped[:5] + self.assertIsInstance(child, indexing.CopyOnWriteArray) + child[:] = 0 + self.assertArrayEqual(original, np.arange(10)) + self.assertArrayEqual(wrapped, np.arange(10)) + self.assertArrayEqual(child, np.zeros(5)) + + +class TestMemoryCachedArray(TestCase): + def test_wrapper(self): + original = indexing.LazilyIndexedArray(np.arange(10)) + wrapped = indexing.MemoryCachedArray(original) + self.assertArrayEqual(wrapped, np.arange(10)) + self.assertIsInstance(wrapped.array, np.ndarray) + + def test_sub_array(self): + original = indexing.LazilyIndexedArray(np.arange(10)) + wrapped = indexing.MemoryCachedArray(original) + child = wrapped[:5] + self.assertIsInstance(child, indexing.MemoryCachedArray) + self.assertArrayEqual(child, np.arange(5)) + self.assertIsInstance(child.array, np.ndarray) + self.assertIsInstance(wrapped.array, indexing.LazilyIndexedArray) + + def test_setitem(self): + original = np.arange(10) + wrapped = indexing.MemoryCachedArray(original) + wrapped[:] = 0 + self.assertArrayEqual(original, np.zeros(10)) diff --git a/xarray/test/test_utils.py b/xarray/test/test_utils.py index 611b45f80d1..db0bf2d202f 100644 --- a/xarray/test/test_utils.py +++ b/xarray/test/test_utils.py @@ -1,6 +1,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import pickle +import pytest + import numpy as np import pandas as pd @@ -171,3 +174,9 @@ def test_hashable(self): self.assertTrue(utils.hashable(v)) for v in [[5, 6], ['seven', '8'], {9: 'ten'}]: self.assertFalse(utils.hashable(v)) + + +def test_no_pickle_mixin(): + obj = utils.NoPickleMixin() + with pytest.raises(TypeError): + pickle.dumps(obj) From 49135f20190cd024dbc7148262026659cc36e955 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Wed, 16 Nov 2016 18:51:04 -0800 Subject: [PATCH 02/14] Fixes for test failures --- xarray/backends/api.py | 15 +++++++++++---- xarray/core/variable.py | 21 ++++++++++++--------- xarray/test/test_backends.py | 28 ++++++++++++++++++++++++++++ 3 files changed, 51 insertions(+), 13 deletions(-) diff --git a/xarray/backends/api.py b/xarray/backends/api.py index 8ae541ede63..96c72569d52 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -127,7 +127,7 @@ def _protect_dataset_variables_inplace(dataset, cache): data = indexing.CopyOnWriteArray(variable._data) if cache: data = indexing.MemoryCachedArray(data) - variable._data = data + variable.data = data def open_dataset(filename_or_obj, group=None, decode_cf=True, @@ -212,7 +212,7 @@ def open_dataset(filename_or_obj, group=None, decode_cf=True, decode_coords = False if cache is None: - cache = chunks is not None + cache = chunks is None def maybe_decode_store(store, lock=False): ds = conventions.decode_cf( @@ -300,7 +300,7 @@ def maybe_decode_store(store, lock=False): def open_dataarray(filename_or_obj, group=None, decode_cf=True, mask_and_scale=True, decode_times=True, concat_characters=True, decode_coords=True, engine=None, - chunks=None, lock=None, drop_variables=None): + chunks=None, lock=None, cache=None, drop_variables=None): """ Opens an DataArray from a netCDF file containing a single data variable. @@ -354,6 +354,13 @@ def open_dataarray(filename_or_obj, group=None, decode_cf=True, used when reading data from netCDF files with the netcdf4 and h5netcdf engines to avoid issues with concurrent access when using dask's multithreaded backend. + cache : bool, optional + If True, cache data loaded from the underlying datastore in memory as + NumPy arrays when accessed to avoid reading from the underlying data- + store multiple times. Defaults to True unless you specify the `chunks` + argument to use dask, in which case it defaults to False. Does not + change the behavior of coordinates corresponding to dimensions, which + always load their data from disk into a ``pandas.Index``. drop_variables: string or iterable, optional A variable or list of variables to exclude from being parsed from the dataset. This may be useful to drop variables with problems or @@ -375,7 +382,7 @@ def open_dataarray(filename_or_obj, group=None, decode_cf=True, dataset = open_dataset(filename_or_obj, group, decode_cf, mask_and_scale, decode_times, concat_characters, decode_coords, engine, - chunks, lock, drop_variables) + chunks, lock, cache, drop_variables) if len(dataset.data_vars) != 1: raise ValueError('Given file dataset contains more than one data ' diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 47746e07cad..8c4c6e97979 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -441,14 +441,17 @@ def copy(self, deep=True): If `deep=True`, the data array is loaded into memory and copied onto the new object. Dimensions, attributes and encodings are always copied. """ - if (deep and not isinstance(self.data, dask_array_type) - and not isinstance(self._data, PandasIndexAdapter)): - # pandas.Index objects are immutable - # dask arrays don't have a copy method - # https://github.com/blaze/dask/issues/911 - data = self.data.copy() - else: - data = self._data + data = self._data + + if isinstance(data, indexing.MemoryCachedArray): + # don't share caching between copies + data = indexing.MemoryCachedArray(data.array) + + if deep and not isinstance( + data, (dask_array_type, PandasIndexAdapter)): + # pandas.Index and dask.array objects are immutable + data = np.array(data) + # note: # dims is already an immutable tuple # attributes and encoding will be copied when the new Array is created @@ -681,7 +684,7 @@ def transpose(self, *dims): if len(dims) == 0: dims = self.dims[::-1] axes = self.get_axis_num(dims) - if len(dims) < 2: # no need to transpose if only one dimension + if len(dims) < 2: # no need to transpose if only one dimension return self.copy(deep=False) data = ops.transpose(self.data, axes) return type(self)(dims, data, self._attrs, self._encoding, fastpath=True) diff --git a/xarray/test/test_backends.py b/xarray/test/test_backends.py index c9200907c7e..dabf88d1b13 100644 --- a/xarray/test/test_backends.py +++ b/xarray/test/test_backends.py @@ -20,6 +20,7 @@ open_mfdataset, backends, save_mfdataset) from xarray.backends.common import robust_getitem from xarray.backends.netCDF4_ import _extract_nc4_encoding +from xarray.core import indexing from xarray.core.pycompat import iteritems, PY3 from . import (TestCase, requires_scipy, requires_netCDF4, requires_pydap, @@ -175,6 +176,22 @@ def test_dataset_compute(self): self.assertDatasetAllClose(expected, actual) self.assertDatasetAllClose(expected, computed) + def test_dataset_caching(self): + expected = Dataset({'foo': ('x', [5, 6, 7])}) + with self.roundtrip(expected) as actual: + assert isinstance(actual.foo.variable._data, + indexing.MemoryCachedArray) + assert not actual.foo.variable._in_memory + actual.foo.values # cache + assert actual.foo.variable._in_memory + + with self.roundtrip(expected, open_kwargs={'cache': False}) as actual: + assert isinstance(actual.foo.variable._data, + indexing.CopyOnWriteArray) + assert not actual.foo.variable._in_memory + actual.foo.values # no caching + assert not actual.foo.variable._in_memory + def test_roundtrip_None_variable(self): expected = Dataset({None: (('x', 'y'), [[0, 1], [2, 3]])}) with self.roundtrip(expected) as actual: @@ -723,6 +740,10 @@ def test_unsorted_index_raises(self): # dask first pulls items by block. pass + def test_dataset_caching(self): + # caching behavior differs for dask + pass + @requires_scipy class ScipyInMemoryDataTest(CFEncodedDataTest, Only32BitTypes, TestCase): @@ -939,6 +960,13 @@ def test_write_store(self): # Override method in DatasetIOTestCases - not applicable to dask pass + def test_dataset_caching(self): + expected = Dataset({'foo': ('x', [5, 6, 7])}) + with self.roundtrip(expected) as actual: + assert not actual.foo.variable._in_memory + actual.foo.values # no caching + assert not actual.foo.variable._in_memory + def test_open_mfdataset(self): original = Dataset({'foo': ('x', np.random.randn(10))}) with create_tmp_file() as tmp1: From 7f70d15c6d93e080fa4913c316703edd3fa6854b Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Sat, 19 Nov 2016 19:12:01 -0800 Subject: [PATCH 03/14] Fix IndexVariable.load --- xarray/core/common.py | 8 -------- xarray/core/dataset.py | 8 -------- xarray/core/variable.py | 4 ++++ xarray/test/test_variable.py | 9 +++++++++ 4 files changed, 13 insertions(+), 16 deletions(-) diff --git a/xarray/core/common.py b/xarray/core/common.py index 7afa20a6864..5ac9994ee8c 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -248,14 +248,6 @@ def __dir__(self): if isinstance(item, basestring)] return sorted(set(dir(type(self)) + extra_attrs)) - def __getstate__(self): - """Get this object's state for pickling""" - # we need a custom method to avoid - - # self.__dict__ is the default pickle object, we don't need to - # implement our own __setstate__ method to make pickle work - return self.__dict__ - class SharedMethodsMixin(object): """Shared methods for Dataset, DataArray and Variable.""" diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 3ea04495829..b3bde8f7377 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -259,14 +259,6 @@ def load_store(cls, store, decoder=None): obj._file_obj = store return obj - def __getstate__(self): - """Get this object's state for pickling""" - # we need a custom method to avoid - - # self.__dict__ is the default pickle object, we don't need to - # implement our own __setstate__ method to make pickle work - return self.__dict__ - @property def variables(self): """Frozen dictionary of xarray.Variable objects constituting this diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 8c4c6e97979..97a75019c25 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -1087,6 +1087,10 @@ def __init__(self, dims, data, attrs=None, encoding=None, fastpath=False): if not isinstance(self._data, PandasIndexAdapter): self._data = PandasIndexAdapter(self._data) + def load(self): + # data is already loaded into memory for IndexVariable + return self + @Variable.data.setter def data(self, data): Variable.data.fset(self, data) diff --git a/xarray/test/test_variable.py b/xarray/test/test_variable.py index e79d30eeb6e..88b03f8a214 100644 --- a/xarray/test/test_variable.py +++ b/xarray/test/test_variable.py @@ -449,6 +449,15 @@ def test_multiindex(self): self.assertVariableIdentical(Variable((), ('a', 0)), v[0]) self.assertVariableIdentical(v, v[:]) + def test_load(self): + array = self.cls('x', np.arange(5)) + orig_data = array._data + copied = array.copy(deep=True) + array.load() + assert type(array._data) is type(orig_data) + assert type(copied._data) is type(orig_data) + self.assertVariableIdentical(array, copied) + class TestVariable(TestCase, VariableSubclassTestCases): cls = staticmethod(Variable) From 8d19a16d20e03ede7b7e9953507328d02104001f Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Sun, 20 Nov 2016 17:48:03 -0800 Subject: [PATCH 04/14] Made DataStores pickle-able --- doc/whats-new.rst | 14 +++++-- xarray/backends/common.py | 19 +++++++++ xarray/backends/h5netcdf_.py | 29 ++++++++------ xarray/backends/netCDF4_.py | 63 +++++++++++++++++------------ xarray/backends/pynio_.py | 33 ++++++++++------ xarray/backends/scipy_.py | 48 ++++++++++++---------- xarray/core/utils.py | 7 ---- xarray/test/__init__.py | 6 +++ xarray/test/test_backends.py | 70 +++++++++------------------------ xarray/test/test_conventions.py | 4 +- xarray/test/test_utils.py | 6 --- 11 files changed, 162 insertions(+), 137 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 9067235abc1..fae837f394c 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -25,12 +25,18 @@ Breaking changes merges will now succeed in cases that previously raised ``xarray.MergeError``. Set ``compat='broadcast_equals'`` to restore the previous default. -- Pickling an xarray object or reading its :py:attr:`~DataArray.values` - property no longer always caches values in a NumPy array. Caching - of ``.values`` read from netCDF files on disk is still the default when - :py:func:`open_dataset` is called with ``cache=True``. +- Reading :py:attr:`~DataArray.values` no longer always caches values in a NumPy + array :issue:`1128`. Caching of ``.values`` on variables read from netCDF + files on disk is still the default when :py:func:`open_dataset` is called with + ``cache=True``. By `Guido Imperiale `_ and `Stephan Hoyer `_. +- Pickling a ``Dataset`` or ``DataArray`` linked to a file on disk no longer + caches its values into memory before pickling :issue:`1128`. Instead, pickle + stores file paths and restores objects by reopening file references. This + enables preliminary, experimental use of xarray for opening files with + `dask.distributed `_. + By `Stephan Hoyer `_. - Coordinates used to index a dimension are now loaded eagerly into :py:class:`pandas.Index` objects, instead of loading the values lazily. By `Guido Imperiale `_. diff --git a/xarray/backends/common.py b/xarray/backends/common.py index bf85930c8df..208611829ee 100644 --- a/xarray/backends/common.py +++ b/xarray/backends/common.py @@ -235,3 +235,22 @@ def store(self, variables, attributes, check_encoding_set=frozenset()): cf_variables, cf_attrs = cf_encoder(variables, attributes) AbstractWritableDataStore.store(self, cf_variables, cf_attrs, check_encoding_set) + + +class DataStorePickleMixin(object): + """Subclasses must define `ds`, `_opener` and `_mode` attributes. + + Do not subclass this class: it is not part of xarray's external API. + """ + + def __getstate__(self): + state = self.__dict__.copy() + del state['ds'] + if self._mode == 'w': + # file has already been created, don't override when restoring + state['_mode'] = 'a' + return state + + def __setstate__(self, state): + self.__dict__.update(state) + self.ds = self._opener(mode=self._mode) diff --git a/xarray/backends/h5netcdf_.py b/xarray/backends/h5netcdf_.py index 38eee9c23d6..e3fece22655 100644 --- a/xarray/backends/h5netcdf_.py +++ b/xarray/backends/h5netcdf_.py @@ -5,11 +5,10 @@ from .. import Variable from ..core import indexing -from ..core.utils import (FrozenOrderedDict, close_on_error, Frozen, - NoPickleMixin) +from ..core.utils import (FrozenOrderedDict, close_on_error, Frozen) from ..core.pycompat import iteritems, bytes_type, unicode_type, OrderedDict -from .common import WritableCFDataStore +from .common import WritableCFDataStore, DataStorePickleMixin from .netCDF4_ import (_nc4_group, _nc4_values_and_dtype, _extract_nc4_encoding, BaseNetCDF4Array) @@ -38,24 +37,32 @@ def _read_attributes(h5netcdf_var): lsd_okay=False, backend='h5netcdf') -class H5NetCDFStore(WritableCFDataStore, NoPickleMixin): +def _open_h5netcdf_group(filename, mode, group): + import h5netcdf.legacyapi + ds = h5netcdf.legacyapi.Dataset(filename, mode=mode) + with close_on_error(ds): + return _nc4_group(ds, group, mode) + + +class H5NetCDFStore(WritableCFDataStore, DataStorePickleMixin): """Store for reading and writing data via h5netcdf """ def __init__(self, filename, mode='r', format=None, group=None, writer=None): - import h5netcdf.legacyapi if format not in [None, 'NETCDF4']: raise ValueError('invalid format for h5netcdf backend') - ds = h5netcdf.legacyapi.Dataset(filename, mode=mode) - with close_on_error(ds): - self.ds = _nc4_group(ds, group, mode) + opener = functools.partial(_open_h5netcdf_group, filename, mode=mode, + group=group) + self.ds = opener() self.format = format + self._opener = opener self._filename = filename + self._mode = mode super(H5NetCDFStore, self).__init__(writer) - def open_store_variable(self, var): + def open_store_variable(self, name, var): dimensions = var.dimensions - data = indexing.LazilyIndexedArray(BaseNetCDF4Array(var)) + data = indexing.LazilyIndexedArray(BaseNetCDF4Array(name, self)) attrs = _read_attributes(var) # netCDF4 specific encoding @@ -70,7 +77,7 @@ def open_store_variable(self, var): return Variable(dimensions, data, attrs, encoding) def get_variables(self): - return FrozenOrderedDict((k, self.open_store_variable(v)) + return FrozenOrderedDict((k, self.open_store_variable(k, v)) for k, v in iteritems(self.ds.variables)) def get_attrs(self): diff --git a/xarray/backends/netCDF4_.py b/xarray/backends/netCDF4_.py index 70760030298..9b05a41925d 100644 --- a/xarray/backends/netCDF4_.py +++ b/xarray/backends/netCDF4_.py @@ -1,19 +1,19 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import functools import operator -from functools import partial import numpy as np from .. import Variable from ..conventions import pop_to, cf_encoder from ..core import indexing -from ..core.utils import (FrozenOrderedDict, NDArrayMixin, NoPickleMixin, +from ..core.utils import (FrozenOrderedDict, NDArrayMixin, close_on_error, is_remote_uri) from ..core.pycompat import iteritems, basestring, OrderedDict, PY3 -from .common import WritableCFDataStore, robust_getitem +from .common import WritableCFDataStore, robust_getitem, DataStorePickleMixin from .netcdf3 import (encode_nc3_attr_value, encode_nc3_variable, maybe_convert_to_char_array) @@ -25,10 +25,14 @@ '|': 'native'} -class BaseNetCDF4Array(NDArrayMixin, NoPickleMixin): - def __init__(self, array, is_remote=False): - self.array = array - self.is_remote = is_remote +class BaseNetCDF4Array(NDArrayMixin): + def __init__(self, variable_name, datastore): + self.datastore = datastore + self.variable_name = variable_name + + @property + def array(self): + return self.datastore.ds.variables[self.variable_name] @property def dtype(self): @@ -42,13 +46,9 @@ def dtype(self): class NetCDF4ArrayWrapper(BaseNetCDF4Array): - def __init__(self, array, is_remote=False): - self.array = array - self.is_remote = is_remote - def __getitem__(self, key): - if self.is_remote: # pragma: no cover - getitem = partial(robust_getitem, catch=RuntimeError) + if self.datastore.is_remote: # pragma: no cover + getitem = functools.partial(robust_getitem, catch=RuntimeError) else: getitem = operator.getitem @@ -176,31 +176,44 @@ def _extract_nc4_encoding(variable, raise_on_invalid=False, lsd_okay=True, return encoding -class NetCDF4DataStore(WritableCFDataStore, NoPickleMixin): +def _open_netcdf4_group(filename, mode, group=None, **kwargs): + import netCDF4 as nc4 + + ds = nc4.Dataset(filename, mode=mode, **kwargs) + + with close_on_error(ds): + ds = _nc4_group(ds, group, mode) + + for var in ds.variables.values(): + # we handle masking and scaling ourselves + var.set_auto_maskandscale(False) + return ds + + +class NetCDF4DataStore(WritableCFDataStore, DataStorePickleMixin): """Store for reading and writing data via the Python-NetCDF4 library. This store supports NetCDF3, NetCDF4 and OpenDAP datasets. """ def __init__(self, filename, mode='r', format='NETCDF4', group=None, writer=None, clobber=True, diskless=False, persist=False): - import netCDF4 as nc4 if format is None: format = 'NETCDF4' - ds = nc4.Dataset(filename, mode=mode, clobber=clobber, - diskless=diskless, persist=persist, - format=format) - with close_on_error(ds): - self.ds = _nc4_group(ds, group, mode) + opener = functools.partial(_open_netcdf4_group, filename, mode=mode, + group=group, clobber=clobber, + diskless=diskless, persist=persist, + format=format) + self.ds = opener() self.format = format self.is_remote = is_remote_uri(filename) + self._opener = opener self._filename = filename + self._mode = 'a' if mode == 'w' else mode super(NetCDF4DataStore, self).__init__(writer) - def open_store_variable(self, var): - var.set_auto_maskandscale(False) + def open_store_variable(self, name, var): dimensions = var.dimensions - data = indexing.LazilyIndexedArray(NetCDF4ArrayWrapper( - var, self.is_remote)) + data = indexing.LazilyIndexedArray(NetCDF4ArrayWrapper(name, self)) attributes = OrderedDict((k, var.getncattr(k)) for k in var.ncattrs()) _ensure_fill_value_valid(data, attributes) @@ -227,7 +240,7 @@ def open_store_variable(self, var): return Variable(dimensions, data, attributes, encoding) def get_variables(self): - return FrozenOrderedDict((k, self.open_store_variable(v)) + return FrozenOrderedDict((k, self.open_store_variable(k, v)) for k, v in iteritems(self.ds.variables)) def get_attrs(self): diff --git a/xarray/backends/pynio_.py b/xarray/backends/pynio_.py index c545065df93..075db5d4ccb 100644 --- a/xarray/backends/pynio_.py +++ b/xarray/backends/pynio_.py @@ -1,19 +1,27 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function + +import functools + import numpy as np from .. import Variable -from ..core.utils import FrozenOrderedDict, Frozen, NDArrayMixin, NoPickleMixin +from ..core.utils import FrozenOrderedDict, Frozen, NDArrayMixin from ..core import indexing -from .common import AbstractDataStore +from .common import AbstractDataStore, DataStorePickleMixin + +class NioArrayWrapper(NDArrayMixin): -class NioArrayWrapper(NDArrayMixin, NoPickleMixin): - def __init__(self, array, ds): - self.array = array - self._ds = ds # make an explicit reference because pynio uses weakrefs + def __init__(self, variable_name, datastore): + self.datastore = datastore + self.variable_name = variable_name + + @property + def array(self): + return self.datastore.ds.variables[self.variable_name] @property def dtype(self): @@ -25,19 +33,22 @@ def __getitem__(self, key): return self.array[key] -class NioDataStore(AbstractDataStore, NoPickleMixin): +class NioDataStore(AbstractDataStore, DataStorePickleMixin): """Store for accessing datasets via PyNIO """ def __init__(self, filename, mode='r'): import Nio - self.ds = Nio.open_file(filename, mode=mode) + opener = functools.partial(Nio.open_file, filename, mode=mode) + self.ds = opener() + self._opener = opener + self._mode = mode - def open_store_variable(self, var): - data = indexing.LazilyIndexedArray(NioArrayWrapper(var, self.ds)) + def open_store_variable(self, name, var): + data = indexing.LazilyIndexedArray(NioArrayWrapper(name, self)) return Variable(var.dimensions, data, var.attributes) def get_variables(self): - return FrozenOrderedDict((k, self.open_store_variable(v)) + return FrozenOrderedDict((k, self.open_store_variable(k, v)) for k, v in self.ds.variables.iteritems()) def get_attrs(self): diff --git a/xarray/backends/scipy_.py b/xarray/backends/scipy_.py index 4aba4366a5c..90cd2192b36 100644 --- a/xarray/backends/scipy_.py +++ b/xarray/backends/scipy_.py @@ -1,6 +1,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import functools from io import BytesIO import numpy as np @@ -8,10 +9,10 @@ from .. import Variable from ..core.pycompat import iteritems, basestring, OrderedDict -from ..core.utils import Frozen, FrozenOrderedDict, NoPickleMixin +from ..core.utils import Frozen, FrozenOrderedDict from ..core.indexing import NumpyIndexingAdapter -from .common import WritableCFDataStore +from .common import WritableCFDataStore, DataStorePickleMixin from .netcdf3 import (is_valid_nc3_name, encode_nc3_attr_value, encode_nc3_variable) @@ -29,18 +30,14 @@ def _decode_attrs(d): for (k, v) in iteritems(d)) -class ScipyArrayWrapper(NumpyIndexingAdapter, NoPickleMixin): - def __init__(self, netcdf_file, variable_name): - self.netcdf_file = netcdf_file +class ScipyArrayWrapper(NumpyIndexingAdapter): + def __init__(self, variable_name, datastore): + self.datastore = datastore self.variable_name = variable_name @property def array(self): - # We can't store the actual netcdf_variable object or its data array, - # because otherwise scipy complains about variables or files still - # referencing mmapped arrays when we try to close datasets without - # having read all data in the file. - return self.netcdf_file.variables[self.variable_name].data + return self.datastore.ds.variables[self.variable_name].data @property def dtype(self): @@ -52,12 +49,12 @@ def __getitem__(self, key): # Copy data if the source file is mmapped. This makes things consistent # with the netCDF4 library by ensuring we can safely read arrays even # after closing associated files. - copy = self.netcdf_file.use_mmap + copy = self.datastore.ds.use_mmap data = np.array(data, dtype=self.dtype, copy=copy) return data -class ScipyDataStore(WritableCFDataStore, NoPickleMixin): +class ScipyDataStore(WritableCFDataStore, DataStorePickleMixin): """Store for reading and writing data via scipy.io.netcdf. This store has the advantage of being able to be initialized with a @@ -88,18 +85,22 @@ def __init__(self, filename_or_obj, mode='r', format=None, group=None, raise ValueError('invalid format for scipy.io.netcdf backend: %r' % format) - # if filename is a NetCDF3 bytestring we store it in a StringIO - if (isinstance(filename_or_obj, basestring) and - filename_or_obj.startswith('CDF')): - # TODO: this check has the unfortunate side-effect that - # paths to files cannot start with 'CDF'. + if (isinstance(filename_or_obj, bytes) and + filename_or_obj.startswith(b'CDF')): + # it's a NetCDF3 bytestring filename_or_obj = BytesIO(filename_or_obj) - self.ds = scipy.io.netcdf_file( - filename_or_obj, mode=mode, mmap=mmap, version=version) + + opener = functools.partial(scipy.io.netcdf_file, + filename=filename_or_obj, + mode=mode, mmap=mmap, version=version) + self.ds = opener() + self._opener = opener + self._mode = mode + super(ScipyDataStore, self).__init__(writer) def open_store_variable(self, name, var): - return Variable(var.dimensions, ScipyArrayWrapper(self.ds, name), + return Variable(var.dimensions, ScipyArrayWrapper(name, self), _decode_attrs(var._attributes)) def get_variables(self): @@ -154,3 +155,10 @@ def close(self): def __exit__(self, type, value, tb): self.close() + + def __setstate__(self, state): + filename = state['_opener'].keywords['filename'] + if not isinstance(filename, basestring): + # it's a file-like object + filename.seek(0) + super(ScipyDataStore, self).__setstate__(state) diff --git a/xarray/core/utils.py b/xarray/core/utils.py index 5a701680f4d..ff7ccbc2670 100644 --- a/xarray/core/utils.py +++ b/xarray/core/utils.py @@ -420,13 +420,6 @@ def __repr__(self): return '%s(array=%r)' % (type(self).__name__, self.array) -class NoPickleMixin(object): - def __getstate__(self): - raise TypeError( - 'cannot pickle objects of type %r: call .compute() or .load() ' - 'to load data into memory first.' % type(self)) - - @contextlib.contextmanager def close_on_error(f): """Context manager to ensure that a file opened by xarray is closed if an diff --git a/xarray/test/__init__.py b/xarray/test/__init__.py index e2ac8d9c2ce..6b4823396ae 100644 --- a/xarray/test/__init__.py +++ b/xarray/test/__init__.py @@ -126,6 +126,12 @@ def data_allclose_or_equiv(arr1, arr2, rtol=1e-05, atol=1e-08): return ops.allclose_or_equiv(arr1, arr2, rtol=rtol, atol=atol) +def assert_dataset_allclose(first, second): + # TODO(shoyer): make a lighter weight version of this that doesn't require + # constructing a dummy TestCase. + TestCase.assertDatasetAllClose(TestCase(), first, second) + + class TestCase(unittest.TestCase): if PY3: # Python 3 assertCountEqual is roughly equivalent to Python 2 diff --git a/xarray/test/test_backends.py b/xarray/test/test_backends.py index dabf88d1b13..299aa497af0 100644 --- a/xarray/test/test_backends.py +++ b/xarray/test/test_backends.py @@ -131,10 +131,7 @@ def assert_loads(vars=None): with self.roundtrip(expected) as actual: for k, v in actual.variables.items(): # IndexVariables are eagerly loaded into memory - if k in actual.dims: - self.assertTrue(v._in_memory) - else: - self.assertFalse(v._in_memory) + self.assertEqual(v._in_memory, k in actual.dims) yield actual for k, v in actual.variables.items(): if k in vars: @@ -176,6 +173,15 @@ def test_dataset_compute(self): self.assertDatasetAllClose(expected, actual) self.assertDatasetAllClose(expected, computed) + def test_pickle(self): + expected = Dataset({'foo': ('x', [42])}) + with self.roundtrip(expected) as roundtripped: + unpickled_ds = pickle.loads(pickle.dumps(roundtripped)) + self.assertDatasetIdentical(expected, unpickled_ds) + + unpickled_array = pickle.loads(pickle.dumps(roundtripped['foo'])) + self.assertDatasetIdentical(expected['foo'], unpickled_array) + def test_dataset_caching(self): expected = Dataset({'foo': ('x', [5, 6, 7])}) with self.roundtrip(expected) as actual: @@ -294,41 +300,6 @@ def test_orthogonal_indexing(self): self.assertDatasetAllClose(expected, actual) -class PickleNotSupportedMixin(object): - - def test_pickle(self): - expected = Dataset({'foo': ('x', [42])}) - with self.roundtrip(expected) as on_disk: - with self.assertRaisesRegexp( - TypeError, 'load data into memory first'): - pickle.dumps(on_disk) - computed_ds = on_disk.compute() - unpickled_ds = pickle.loads(pickle.dumps(computed_ds)) - self.assertDatasetIdentical(expected, computed_ds) - self.assertDatasetIdentical(expected, unpickled_ds) - - with self.assertRaisesRegexp( - TypeError, 'load data into memory first'): - pickle.dumps(on_disk['foo']) - computed_array = on_disk['foo'].compute() - unpickled_array = pickle.loads(pickle.dumps(computed_array)) - self.assertDatasetIdentical(expected['foo'], computed_array) - self.assertDatasetIdentical(expected['foo'], unpickled_array) - - -class PickleSupportedMixin(object): - - def test_pickle(self): - # this should work for dask arrays, unlike most real data stores - expected = Dataset({'foo': ('x', [42])}) - with self.roundtrip(expected) as roundtripped: - unpickled_ds = pickle.loads(pickle.dumps(roundtripped)) - self.assertDatasetIdentical(expected, unpickled_ds) - - unpickled_array = pickle.loads(pickle.dumps(roundtripped['foo'])) - self.assertDatasetIdentical(expected['foo'], unpickled_array) - - class CFEncodedDataTest(DatasetIOTestCases): def test_roundtrip_strings_with_fill_value(self): @@ -484,7 +455,7 @@ def create_tmp_file(suffix='.nc'): shutil.rmtree(temp_dir) -class BaseNetCDF4Test(CFEncodedDataTest, PickleNotSupportedMixin): +class BaseNetCDF4Test(CFEncodedDataTest): def test_open_group(self): # Create a netCDF file with a dataset stored within a group with create_tmp_file() as tmp_file: @@ -683,7 +654,7 @@ def test_variable_len_strings(self): @requires_netCDF4 -class NetCDF4DataTest(BaseNetCDF4Test, PickleNotSupportedMixin, TestCase): +class NetCDF4DataTest(BaseNetCDF4Test, TestCase): @contextlib.contextmanager def create_store(self): with create_tmp_file() as tmp_file: @@ -728,7 +699,7 @@ def test_unsorted_index_raises(self): @requires_netCDF4 @requires_dask -class NetCDF4ViaDaskDataTest(NetCDF4DataTest, PickleNotSupportedMixin): +class NetCDF4ViaDaskDataTest(NetCDF4DataTest): @contextlib.contextmanager def roundtrip(self, data, save_kwargs={}, open_kwargs={}): with NetCDF4DataTest.roundtrip( @@ -798,8 +769,7 @@ def test_netcdf3_endianness(self): @requires_netCDF4 -class NetCDF3ViaNetCDF4DataTest(CFEncodedDataTest, Only32BitTypes, - PickleNotSupportedMixin, TestCase): +class NetCDF3ViaNetCDF4DataTest(CFEncodedDataTest, Only32BitTypes, TestCase): @contextlib.contextmanager def create_store(self): with create_tmp_file() as tmp_file: @@ -818,7 +788,7 @@ def roundtrip(self, data, save_kwargs={}, open_kwargs={}): @requires_netCDF4 class NetCDF4ClassicViaNetCDF4DataTest(CFEncodedDataTest, Only32BitTypes, - PickleNotSupportedMixin, TestCase): + TestCase): @contextlib.contextmanager def create_store(self): with create_tmp_file() as tmp_file: @@ -836,8 +806,7 @@ def roundtrip(self, data, save_kwargs={}, open_kwargs={}): @requires_scipy_or_netCDF4 -class GenericNetCDFDataTest(CFEncodedDataTest, Only32BitTypes, - PickleNotSupportedMixin, TestCase): +class GenericNetCDFDataTest(CFEncodedDataTest, Only32BitTypes, TestCase): # verify that we can read and write netCDF3 files as long as we have scipy # or netCDF4-python installed @@ -889,7 +858,7 @@ def test_cross_engine_read_write_netcdf3(self): @requires_h5netcdf @requires_netCDF4 -class H5NetCDFDataTest(BaseNetCDF4Test, PickleNotSupportedMixin, TestCase): +class H5NetCDFDataTest(BaseNetCDF4Test, TestCase): @contextlib.contextmanager def create_store(self): with create_tmp_file() as tmp_file: @@ -940,7 +909,7 @@ def test_read_byte_attrs_as_unicode(self): @requires_dask @requires_scipy @requires_netCDF4 -class DaskTest(TestCase, PickleSupportedMixin, DatasetIOTestCases): +class DaskTest(TestCase, DatasetIOTestCases): @contextlib.contextmanager def create_store(self): yield Dataset() @@ -1139,8 +1108,7 @@ def test_dask(self): @requires_scipy @requires_pynio -class TestPyNio(CFEncodedDataTest, PickleNotSupportedMixin, Only32BitTypes, - TestCase): +class TestPyNio(CFEncodedDataTest, Only32BitTypes, TestCase): def test_write_store(self): # pynio is read-only for now pass diff --git a/xarray/test/test_conventions.py b/xarray/test/test_conventions.py index e3c2c2bd0ff..9821369e18e 100644 --- a/xarray/test/test_conventions.py +++ b/xarray/test/test_conventions.py @@ -9,7 +9,7 @@ from xarray import conventions, Variable, Dataset, open_dataset from xarray.core import utils, indexing from . import TestCase, requires_netCDF4, unittest -from .test_backends import CFEncodedDataTest, PickleSupportedMixin +from .test_backends import CFEncodedDataTest from xarray.core.pycompat import iteritems from xarray.backends.memory import InMemoryDataStore from xarray.backends.common import WritableCFDataStore @@ -620,7 +620,7 @@ def null_wrap(ds): @requires_netCDF4 -class TestCFEncodedDataStore(CFEncodedDataTest, PickleSupportedMixin, TestCase): +class TestCFEncodedDataStore(CFEncodedDataTest, TestCase): @contextlib.contextmanager def create_store(self): yield CFEncodedInMemoryStore() diff --git a/xarray/test/test_utils.py b/xarray/test/test_utils.py index db0bf2d202f..ded618ddcab 100644 --- a/xarray/test/test_utils.py +++ b/xarray/test/test_utils.py @@ -174,9 +174,3 @@ def test_hashable(self): self.assertTrue(utils.hashable(v)) for v in [[5, 6], ['seven', '8'], {9: 'ten'}]: self.assertFalse(utils.hashable(v)) - - -def test_no_pickle_mixin(): - obj = utils.NoPickleMixin() - with pytest.raises(TypeError): - pickle.dumps(obj) From 5f5ca9e63366789810480684d43e1fd175e6900c Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Sun, 20 Nov 2016 18:01:37 -0800 Subject: [PATCH 05/14] Add dask.distributed test --- ci/requirements-py27-cdat+pynio.yml | 1 + ci/requirements-py27-netcdf4-dev.yml | 1 + ci/requirements-py27-pydap.yml | 1 + ci/requirements-py35.yml | 1 + xarray/backends/api.py | 1 - xarray/backends/h5netcdf_.py | 4 +++- xarray/backends/netCDF4_.py | 3 ++- xarray/backends/pynio_.py | 3 ++- xarray/backends/scipy_.py | 6 ++++- xarray/core/utils.py | 8 +++++++ xarray/test/test_distributed.py | 36 ++++++++++++++++++++++++++++ 11 files changed, 60 insertions(+), 5 deletions(-) create mode 100644 xarray/test/test_distributed.py diff --git a/ci/requirements-py27-cdat+pynio.yml b/ci/requirements-py27-cdat+pynio.yml index 53aafb058e1..5f98b9e1f6f 100644 --- a/ci/requirements-py27-cdat+pynio.yml +++ b/ci/requirements-py27-cdat+pynio.yml @@ -5,6 +5,7 @@ dependencies: - python=2.7 - cdat-lite - dask + - distributed - pytest - numpy - pandas>=0.15.0 diff --git a/ci/requirements-py27-netcdf4-dev.yml b/ci/requirements-py27-netcdf4-dev.yml index a64782de235..4ce193a2a82 100644 --- a/ci/requirements-py27-netcdf4-dev.yml +++ b/ci/requirements-py27-netcdf4-dev.yml @@ -3,6 +3,7 @@ dependencies: - python=2.7 - cython - dask + - distributed - h5py - pytest - numpy diff --git a/ci/requirements-py27-pydap.yml b/ci/requirements-py27-pydap.yml index 459f049c76a..e391eee514f 100644 --- a/ci/requirements-py27-pydap.yml +++ b/ci/requirements-py27-pydap.yml @@ -2,6 +2,7 @@ name: test_env dependencies: - python=2.7 - dask + - distributed - h5py - netcdf4 - pytest diff --git a/ci/requirements-py35.yml b/ci/requirements-py35.yml index 0f3b005ea6a..c6641598fca 100644 --- a/ci/requirements-py35.yml +++ b/ci/requirements-py35.yml @@ -3,6 +3,7 @@ dependencies: - python=3.5 - cython - dask + - distributed - h5py - matplotlib - netcdf4 diff --git a/xarray/backends/api.py b/xarray/backends/api.py index 96c72569d52..295837e7eef 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -16,7 +16,6 @@ from ..core import indexing from ..core.combine import auto_combine from ..core.utils import close_on_error, is_remote_uri -from ..core.variable import Variable, IndexVariable from ..core.pycompat import basestring DATAARRAY_NAME = '__xarray_dataarray_name__' diff --git a/xarray/backends/h5netcdf_.py b/xarray/backends/h5netcdf_.py index e3fece22655..c4ae7bcedcf 100644 --- a/xarray/backends/h5netcdf_.py +++ b/xarray/backends/h5netcdf_.py @@ -5,7 +5,8 @@ from .. import Variable from ..core import indexing -from ..core.utils import (FrozenOrderedDict, close_on_error, Frozen) +from ..core.utils import (FrozenOrderedDict, close_on_error, Frozen, + normalize_path) from ..core.pycompat import iteritems, bytes_type, unicode_type, OrderedDict from .common import WritableCFDataStore, DataStorePickleMixin @@ -51,6 +52,7 @@ def __init__(self, filename, mode='r', format=None, group=None, writer=None): if format not in [None, 'NETCDF4']: raise ValueError('invalid format for h5netcdf backend') + filename = normalize_path(filename) opener = functools.partial(_open_h5netcdf_group, filename, mode=mode, group=group) self.ds = opener() diff --git a/xarray/backends/netCDF4_.py b/xarray/backends/netCDF4_.py index 9b05a41925d..8b026628124 100644 --- a/xarray/backends/netCDF4_.py +++ b/xarray/backends/netCDF4_.py @@ -10,7 +10,7 @@ from ..conventions import pop_to, cf_encoder from ..core import indexing from ..core.utils import (FrozenOrderedDict, NDArrayMixin, - close_on_error, is_remote_uri) + close_on_error, is_remote_uri, normalize_path) from ..core.pycompat import iteritems, basestring, OrderedDict, PY3 from .common import WritableCFDataStore, robust_getitem, DataStorePickleMixin @@ -199,6 +199,7 @@ def __init__(self, filename, mode='r', format='NETCDF4', group=None, writer=None, clobber=True, diskless=False, persist=False): if format is None: format = 'NETCDF4' + filename = normalize_path(filename) opener = functools.partial(_open_netcdf4_group, filename, mode=mode, group=group, clobber=clobber, diskless=diskless, persist=persist, diff --git a/xarray/backends/pynio_.py b/xarray/backends/pynio_.py index 075db5d4ccb..0693372d146 100644 --- a/xarray/backends/pynio_.py +++ b/xarray/backends/pynio_.py @@ -7,7 +7,7 @@ import numpy as np from .. import Variable -from ..core.utils import FrozenOrderedDict, Frozen, NDArrayMixin +from ..core.utils import FrozenOrderedDict, Frozen, NDArrayMixin, normalize_path from ..core import indexing from .common import AbstractDataStore, DataStorePickleMixin @@ -38,6 +38,7 @@ class NioDataStore(AbstractDataStore, DataStorePickleMixin): """ def __init__(self, filename, mode='r'): import Nio + filename = normalize_path(filename) opener = functools.partial(Nio.open_file, filename, mode=mode) self.ds = opener() self._opener = opener diff --git a/xarray/backends/scipy_.py b/xarray/backends/scipy_.py index 90cd2192b36..5d3146c9a16 100644 --- a/xarray/backends/scipy_.py +++ b/xarray/backends/scipy_.py @@ -9,7 +9,7 @@ from .. import Variable from ..core.pycompat import iteritems, basestring, OrderedDict -from ..core.utils import Frozen, FrozenOrderedDict +from ..core.utils import Frozen, FrozenOrderedDict, normalize_path from ..core.indexing import NumpyIndexingAdapter from .common import WritableCFDataStore, DataStorePickleMixin @@ -90,6 +90,10 @@ def __init__(self, filename_or_obj, mode='r', format=None, group=None, # it's a NetCDF3 bytestring filename_or_obj = BytesIO(filename_or_obj) + if isinstance(filename_or_obj, basestring): + # not a file-like object + filename_or_obj = normalize_path(filename_or_obj) + opener = functools.partial(scipy.io.netcdf_file, filename=filename_or_obj, mode=mode, mmap=mmap, version=version) diff --git a/xarray/core/utils.py b/xarray/core/utils.py index ff7ccbc2670..d6e1119011c 100644 --- a/xarray/core/utils.py +++ b/xarray/core/utils.py @@ -6,6 +6,7 @@ import contextlib import functools import itertools +import os.path import re import warnings from collections import Mapping, MutableMapping, Iterable @@ -436,6 +437,13 @@ def is_remote_uri(path): return bool(re.search('^https?\://', path)) +def normalize_path(path): + if is_remote_uri(path): + return path + else: + return os.path.abspath(os.path.expanduser(path)) + + def is_uniform_spaced(arr, **kwargs): """Return True if values of an array are uniformly spaced and sorted. diff --git a/xarray/test/test_distributed.py b/xarray/test/test_distributed.py new file mode 100644 index 00000000000..a807f72387a --- /dev/null +++ b/xarray/test/test_distributed.py @@ -0,0 +1,36 @@ +import pytest +import xarray as xr +from xarray.core.pycompat import suppress + +distributed = pytest.importorskip('distributed') +da = pytest.importorskip('dask.array') +from distributed.utils_test import cluster, loop + +from xarray.test.test_backends import create_tmp_file +from xarray.test.test_dataset import create_test_data + +from . import assert_dataset_allclose, has_scipy, has_netCDF4, has_h5netcdf + + +ENGINES = [] +if has_scipy: + ENGINES.append('scipy') +if has_netCDF4: + ENGINES.append('netcdf4') +if has_h5netcdf: + ENGINES.append('h5netcdf') + + +@pytest.mark.parametrize('engine', ENGINES) +def test_dask_distributed_integration_test(loop, engine): + with cluster() as (s, _): + with distributed.Client(('127.0.0.1', s['port']), loop=loop): + original = create_test_data() + with create_tmp_file() as filename: + original.to_netcdf(filename, engine=engine) + # TODO: should be able to serialize locks + restored = xr.open_dataset(filename, chunks=3, lock=False, + engine=engine) + assert isinstance(restored.var1.data, da.Array) + computed = restored.compute() + assert_dataset_allclose(original, computed) From c85dce7d1aaec73661ab493a832c3b1a908d87d2 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Sun, 20 Nov 2016 19:51:26 -0800 Subject: [PATCH 06/14] Fix failing Python 2 tests --- xarray/backends/api.py | 24 ++++++++++++++++++++++-- xarray/backends/h5netcdf_.py | 4 +--- xarray/backends/netCDF4_.py | 3 +-- xarray/backends/pynio_.py | 3 +-- xarray/backends/scipy_.py | 27 +++++++++++++++------------ xarray/core/utils.py | 7 ------- xarray/test/__init__.py | 14 ++++++++++---- xarray/test/test_backends.py | 13 +++++++++++-- 8 files changed, 61 insertions(+), 34 deletions(-) diff --git a/xarray/backends/api.py b/xarray/backends/api.py index 295837e7eef..bc2afa4b373 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -48,6 +48,13 @@ def _get_default_engine(path, allow_remote=False): return engine +def _normalize_path(path): + if is_remote_uri(path): + return path + else: + return os.path.abspath(os.path.expanduser(path)) + + _global_lock = threading.Lock() @@ -251,6 +258,17 @@ def maybe_decode_store(store, lock=False): if isinstance(filename_or_obj, backends.AbstractDataStore): store = filename_or_obj elif isinstance(filename_or_obj, basestring): + + if (isinstance(filename_or_obj, bytes) and + filename_or_obj.startswith(b'\x89HDF')): + raise ValueError('cannot read netCDF4/HDF5 file images') + elif (isinstance(filename_or_obj, bytes) and + filename_or_obj.startswith(b'CDF')): + # netCDF3 file images are handled by scipy + pass + elif isinstance(filename_or_obj, basestring): + filename_or_obj = _normalize_path(filename_or_obj) + if filename_or_obj.endswith('.gz'): if engine is not None and engine != 'scipy': raise ValueError('can only read gzipped netCDF files with ' @@ -526,8 +544,10 @@ def to_netcdf(dataset, path=None, mode='w', format=None, group=None, raise ValueError('invalid engine for creating bytes with ' 'to_netcdf: %r. Only the default engine ' "or engine='scipy' is supported" % engine) - elif engine is None: - engine = _get_default_engine(path) + else: + if engine is None: + engine = _get_default_engine(path) + path = _normalize_path(path) # validate Dataset keys, DataArray names, and attr keys/values _validate_dataset_names(dataset) diff --git a/xarray/backends/h5netcdf_.py b/xarray/backends/h5netcdf_.py index c4ae7bcedcf..76582cfd72e 100644 --- a/xarray/backends/h5netcdf_.py +++ b/xarray/backends/h5netcdf_.py @@ -5,8 +5,7 @@ from .. import Variable from ..core import indexing -from ..core.utils import (FrozenOrderedDict, close_on_error, Frozen, - normalize_path) +from ..core.utils import FrozenOrderedDict, close_on_error, Frozen from ..core.pycompat import iteritems, bytes_type, unicode_type, OrderedDict from .common import WritableCFDataStore, DataStorePickleMixin @@ -52,7 +51,6 @@ def __init__(self, filename, mode='r', format=None, group=None, writer=None): if format not in [None, 'NETCDF4']: raise ValueError('invalid format for h5netcdf backend') - filename = normalize_path(filename) opener = functools.partial(_open_h5netcdf_group, filename, mode=mode, group=group) self.ds = opener() diff --git a/xarray/backends/netCDF4_.py b/xarray/backends/netCDF4_.py index 8b026628124..9b05a41925d 100644 --- a/xarray/backends/netCDF4_.py +++ b/xarray/backends/netCDF4_.py @@ -10,7 +10,7 @@ from ..conventions import pop_to, cf_encoder from ..core import indexing from ..core.utils import (FrozenOrderedDict, NDArrayMixin, - close_on_error, is_remote_uri, normalize_path) + close_on_error, is_remote_uri) from ..core.pycompat import iteritems, basestring, OrderedDict, PY3 from .common import WritableCFDataStore, robust_getitem, DataStorePickleMixin @@ -199,7 +199,6 @@ def __init__(self, filename, mode='r', format='NETCDF4', group=None, writer=None, clobber=True, diskless=False, persist=False): if format is None: format = 'NETCDF4' - filename = normalize_path(filename) opener = functools.partial(_open_netcdf4_group, filename, mode=mode, group=group, clobber=clobber, diskless=diskless, persist=persist, diff --git a/xarray/backends/pynio_.py b/xarray/backends/pynio_.py index 0693372d146..075db5d4ccb 100644 --- a/xarray/backends/pynio_.py +++ b/xarray/backends/pynio_.py @@ -7,7 +7,7 @@ import numpy as np from .. import Variable -from ..core.utils import FrozenOrderedDict, Frozen, NDArrayMixin, normalize_path +from ..core.utils import FrozenOrderedDict, Frozen, NDArrayMixin from ..core import indexing from .common import AbstractDataStore, DataStorePickleMixin @@ -38,7 +38,6 @@ class NioDataStore(AbstractDataStore, DataStorePickleMixin): """ def __init__(self, filename, mode='r'): import Nio - filename = normalize_path(filename) opener = functools.partial(Nio.open_file, filename, mode=mode) self.ds = opener() self._opener = opener diff --git a/xarray/backends/scipy_.py b/xarray/backends/scipy_.py index 5d3146c9a16..0113728f81c 100644 --- a/xarray/backends/scipy_.py +++ b/xarray/backends/scipy_.py @@ -9,7 +9,7 @@ from .. import Variable from ..core.pycompat import iteritems, basestring, OrderedDict -from ..core.utils import Frozen, FrozenOrderedDict, normalize_path +from ..core.utils import Frozen, FrozenOrderedDict from ..core.indexing import NumpyIndexingAdapter from .common import WritableCFDataStore, DataStorePickleMixin @@ -54,6 +54,17 @@ def __getitem__(self, key): return data +def _open_scipy_netcdf(filename, mode, mmap, version): + import scipy.io + + if isinstance(filename, bytes) and filename.startswith(b'CDF'): + # it's a NetCDF3 bytestring + filename = BytesIO(filename) + + return scipy.io.netcdf_file(filename, mode=mode, mmap=mmap, + version=version) + + class ScipyDataStore(WritableCFDataStore, DataStorePickleMixin): """Store for reading and writing data via scipy.io.netcdf. @@ -85,16 +96,7 @@ def __init__(self, filename_or_obj, mode='r', format=None, group=None, raise ValueError('invalid format for scipy.io.netcdf backend: %r' % format) - if (isinstance(filename_or_obj, bytes) and - filename_or_obj.startswith(b'CDF')): - # it's a NetCDF3 bytestring - filename_or_obj = BytesIO(filename_or_obj) - - if isinstance(filename_or_obj, basestring): - # not a file-like object - filename_or_obj = normalize_path(filename_or_obj) - - opener = functools.partial(scipy.io.netcdf_file, + opener = functools.partial(_open_scipy_netcdf, filename=filename_or_obj, mode=mode, mmap=mmap, version=version) self.ds = opener() @@ -162,7 +164,8 @@ def __exit__(self, type, value, tb): def __setstate__(self, state): filename = state['_opener'].keywords['filename'] - if not isinstance(filename, basestring): + if hasattr(filename, 'seek'): # it's a file-like object + # seek to the start of the file so scipy can read it filename.seek(0) super(ScipyDataStore, self).__setstate__(state) diff --git a/xarray/core/utils.py b/xarray/core/utils.py index d6e1119011c..838d34c16bb 100644 --- a/xarray/core/utils.py +++ b/xarray/core/utils.py @@ -437,13 +437,6 @@ def is_remote_uri(path): return bool(re.search('^https?\://', path)) -def normalize_path(path): - if is_remote_uri(path): - return path - else: - return os.path.abspath(os.path.expanduser(path)) - - def is_uniform_spaced(arr, **kwargs): """Return True if values of an array are uniformly spaced and sorted. diff --git a/xarray/test/__init__.py b/xarray/test/__init__.py index 6b4823396ae..bdafef7c3ad 100644 --- a/xarray/test/__init__.py +++ b/xarray/test/__init__.py @@ -126,10 +126,16 @@ def data_allclose_or_equiv(arr1, arr2, rtol=1e-05, atol=1e-08): return ops.allclose_or_equiv(arr1, arr2, rtol=rtol, atol=atol) -def assert_dataset_allclose(first, second): - # TODO(shoyer): make a lighter weight version of this that doesn't require - # constructing a dummy TestCase. - TestCase.assertDatasetAllClose(TestCase(), first, second) +def assert_dataset_allclose(d1, d2, rtol=1e-05, atol=1e-08): + assert sorted(d1, key=str) == sorted(d2, key=str) + assert sorted(d1.coords, key=str) == sorted(d2.coords, key=str) + for k in d1: + v1 = d1.variables[k] + v2 = d2.variables[k] + assert v1.dims == v2.dims + allclose = data_allclose_or_equiv( + v1.values, v2.values, rtol=rtol, atol=atol) + assert allclose, (k, v1.values, v2.values) class TestCase(unittest.TestCase): diff --git a/xarray/test/test_backends.py b/xarray/test/test_backends.py index 299aa497af0..ca05a5fb7af 100644 --- a/xarray/test/test_backends.py +++ b/xarray/test/test_backends.py @@ -21,7 +21,7 @@ from xarray.backends.common import robust_getitem from xarray.backends.netCDF4_ import _extract_nc4_encoding from xarray.core import indexing -from xarray.core.pycompat import iteritems, PY3 +from xarray.core.pycompat import iteritems, PY2, PY3 from . import (TestCase, requires_scipy, requires_netCDF4, requires_pydap, requires_scipy_or_netCDF4, requires_dask, requires_h5netcdf, @@ -726,9 +726,18 @@ def create_store(self): @contextlib.contextmanager def roundtrip(self, data, save_kwargs={}, open_kwargs={}): serialized = data.to_netcdf(**save_kwargs) - with open_dataset(BytesIO(serialized), **open_kwargs) as ds: + with open_dataset(serialized, engine='scipy', **open_kwargs) as ds: yield ds + def test_bytesio_pickle(self): + if PY2: + raise unittest.SkipTest('cannot pickle BytesIO on Python 2') + data = Dataset({'foo': ('x', [1, 2, 3])}) + fobj = BytesIO(data.to_netcdf()) + with open_dataset(fobj) as ds: + unpickled = pickle.loads(pickle.dumps(ds)) + self.assertDatasetIdentical(unpickled, data) + @requires_scipy class ScipyOnDiskDataTest(CFEncodedDataTest, Only32BitTypes, TestCase): From 95e6737d2667833ea78790ed9261df7ec2e14679 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Sun, 20 Nov 2016 20:01:31 -0800 Subject: [PATCH 07/14] Fix failing test on Windows --- xarray/test/test_backends.py | 1 + 1 file changed, 1 insertion(+) diff --git a/xarray/test/test_backends.py b/xarray/test/test_backends.py index ca05a5fb7af..0f33d21605f 100644 --- a/xarray/test/test_backends.py +++ b/xarray/test/test_backends.py @@ -179,6 +179,7 @@ def test_pickle(self): unpickled_ds = pickle.loads(pickle.dumps(roundtripped)) self.assertDatasetIdentical(expected, unpickled_ds) + with self.roundtrip(expected) as roundtripped: unpickled_array = pickle.loads(pickle.dumps(roundtripped['foo'])) self.assertDatasetIdentical(expected['foo'], unpickled_array) From 0379bfe33b58b1251b72dd5664c3b99f6b246a62 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Sun, 20 Nov 2016 20:09:40 -0800 Subject: [PATCH 08/14] Alternative fix for windows issue --- xarray/test/test_backends.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/xarray/test/test_backends.py b/xarray/test/test_backends.py index 0f33d21605f..e81b7715d7f 100644 --- a/xarray/test/test_backends.py +++ b/xarray/test/test_backends.py @@ -173,12 +173,14 @@ def test_dataset_compute(self): self.assertDatasetAllClose(expected, actual) self.assertDatasetAllClose(expected, computed) - def test_pickle(self): + def test_pickle_dataset(self): expected = Dataset({'foo': ('x', [42])}) with self.roundtrip(expected) as roundtripped: unpickled_ds = pickle.loads(pickle.dumps(roundtripped)) self.assertDatasetIdentical(expected, unpickled_ds) + def test_pickle_dataarray(self): + expected = Dataset({'foo': ('x', [42])}) with self.roundtrip(expected) as roundtripped: unpickled_array = pickle.loads(pickle.dumps(roundtripped['foo'])) self.assertDatasetIdentical(expected['foo'], unpickled_array) From 9a5364b6ed706ae03c40f25f2bb9a56cedb2b9c8 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Sun, 20 Nov 2016 20:34:23 -0800 Subject: [PATCH 09/14] yet another attempt to fix windows tests --- xarray/test/test_backends.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/xarray/test/test_backends.py b/xarray/test/test_backends.py index e81b7715d7f..4a1ec260e1e 100644 --- a/xarray/test/test_backends.py +++ b/xarray/test/test_backends.py @@ -14,6 +14,7 @@ import numpy as np import pandas as pd +import pytest import xarray as xr from xarray import (Dataset, DataArray, open_dataset, open_dataarray, @@ -173,12 +174,15 @@ def test_dataset_compute(self): self.assertDatasetAllClose(expected, actual) self.assertDatasetAllClose(expected, computed) - def test_pickle_dataset(self): + def test_pickle(self): expected = Dataset({'foo': ('x', [42])}) with self.roundtrip(expected) as roundtripped: unpickled_ds = pickle.loads(pickle.dumps(roundtripped)) - self.assertDatasetIdentical(expected, unpickled_ds) + with unpickled_ds: + self.assertDatasetIdentical(expected, unpickled_ds) + @pytest.mark.skipif(sys.platform == 'win32', + reason='all files on Windows must be closed to delete') def test_pickle_dataarray(self): expected = Dataset({'foo': ('x', [42])}) with self.roundtrip(expected) as roundtripped: @@ -732,9 +736,8 @@ def roundtrip(self, data, save_kwargs={}, open_kwargs={}): with open_dataset(serialized, engine='scipy', **open_kwargs) as ds: yield ds + @pytest.mark.skipif(PY2, reason='cannot pickle BytesIO on Python 2') def test_bytesio_pickle(self): - if PY2: - raise unittest.SkipTest('cannot pickle BytesIO on Python 2') data = Dataset({'foo': ('x', [1, 2, 3])}) fobj = BytesIO(data.to_netcdf()) with open_dataset(fobj) as ds: From 85f29cf62770c6aefb17115555c39c9a6154b2b3 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Sun, 20 Nov 2016 20:38:12 -0800 Subject: [PATCH 10/14] a different windows fix --- xarray/test/test_backends.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/xarray/test/test_backends.py b/xarray/test/test_backends.py index 4a1ec260e1e..198eb65e39b 100644 --- a/xarray/test/test_backends.py +++ b/xarray/test/test_backends.py @@ -177,9 +177,9 @@ def test_dataset_compute(self): def test_pickle(self): expected = Dataset({'foo': ('x', [42])}) with self.roundtrip(expected) as roundtripped: - unpickled_ds = pickle.loads(pickle.dumps(roundtripped)) - with unpickled_ds: - self.assertDatasetIdentical(expected, unpickled_ds) + raw_pickle = pickle.dumps(roundtripped) + with pickle.loads(raw_pickle) as unpickled_ds: + self.assertDatasetIdentical(expected, unpickled_ds) @pytest.mark.skipif(sys.platform == 'win32', reason='all files on Windows must be closed to delete') From 6fa043d1dcba8796d59082b38463f103eda6fdca Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Mon, 21 Nov 2016 08:35:49 -0800 Subject: [PATCH 11/14] yet another attempt to fix test on windows --- xarray/test/test_backends.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/xarray/test/test_backends.py b/xarray/test/test_backends.py index 198eb65e39b..935cb74349b 100644 --- a/xarray/test/test_backends.py +++ b/xarray/test/test_backends.py @@ -178,7 +178,9 @@ def test_pickle(self): expected = Dataset({'foo': ('x', [42])}) with self.roundtrip(expected) as roundtripped: raw_pickle = pickle.dumps(roundtripped) - with pickle.loads(raw_pickle) as unpickled_ds: + # windows doesn't like opening the same file twice + roundtripped.close() + unpickled_ds = pickle.loads(raw_pickle) self.assertDatasetIdentical(expected, unpickled_ds) @pytest.mark.skipif(sys.platform == 'win32', From f804e2e1b0df860d00eb2e1f7ca2632776c5d63b Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Mon, 21 Nov 2016 09:10:46 -0800 Subject: [PATCH 12/14] another attempt at fixing windows --- xarray/test/test_backends.py | 1 + 1 file changed, 1 insertion(+) diff --git a/xarray/test/test_backends.py b/xarray/test/test_backends.py index 935cb74349b..6dab305387c 100644 --- a/xarray/test/test_backends.py +++ b/xarray/test/test_backends.py @@ -182,6 +182,7 @@ def test_pickle(self): roundtripped.close() unpickled_ds = pickle.loads(raw_pickle) self.assertDatasetIdentical(expected, unpickled_ds) + unpickled_ds.close() @pytest.mark.skipif(sys.platform == 'win32', reason='all files on Windows must be closed to delete') From e3b9cd68afb127208c584a9e3b8501283ec867cc Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Mon, 28 Nov 2016 09:23:51 -0800 Subject: [PATCH 13/14] Skip remaining failing test on windows only --- xarray/test/test_backends.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/xarray/test/test_backends.py b/xarray/test/test_backends.py index 6dab305387c..add9bdbdd37 100644 --- a/xarray/test/test_backends.py +++ b/xarray/test/test_backends.py @@ -182,7 +182,6 @@ def test_pickle(self): roundtripped.close() unpickled_ds = pickle.loads(raw_pickle) self.assertDatasetIdentical(expected, unpickled_ds) - unpickled_ds.close() @pytest.mark.skipif(sys.platform == 'win32', reason='all files on Windows must be closed to delete') @@ -725,6 +724,12 @@ def test_dataset_caching(self): # caching behavior differs for dask pass + @pytest.mark.skipif( + sys.platform == 'win32', + reason='something related to deleting unclosed files, see GH1128') + def test_pickle(self): + super(NetCDF4ViaDaskDataTest, self).test_pickle() + @requires_scipy class ScipyInMemoryDataTest(CFEncodedDataTest, Only32BitTypes, TestCase): From b102e205fb862bc9d3dca5917686b7d457884eb2 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Wed, 30 Nov 2016 08:24:40 -0800 Subject: [PATCH 14/14] Allow file cleanup failures on windows --- xarray/test/test_backends.py | 79 ++++++++++++++++++++------------- xarray/test/test_conventions.py | 3 +- 2 files changed, 51 insertions(+), 31 deletions(-) diff --git a/xarray/test/test_backends.py b/xarray/test/test_backends.py index add9bdbdd37..9c074147849 100644 --- a/xarray/test/test_backends.py +++ b/xarray/test/test_backends.py @@ -40,6 +40,9 @@ pass +ON_WINDOWS = sys.platform == 'win32' + + def open_example_dataset(name, *args, **kwargs): return open_dataset(os.path.join(os.path.dirname(__file__), 'data', name), *args, **kwargs) @@ -176,18 +179,18 @@ def test_dataset_compute(self): def test_pickle(self): expected = Dataset({'foo': ('x', [42])}) - with self.roundtrip(expected) as roundtripped: + with self.roundtrip( + expected, allow_cleanup_failure=ON_WINDOWS) as roundtripped: raw_pickle = pickle.dumps(roundtripped) # windows doesn't like opening the same file twice roundtripped.close() unpickled_ds = pickle.loads(raw_pickle) self.assertDatasetIdentical(expected, unpickled_ds) - @pytest.mark.skipif(sys.platform == 'win32', - reason='all files on Windows must be closed to delete') def test_pickle_dataarray(self): expected = Dataset({'foo': ('x', [42])}) - with self.roundtrip(expected) as roundtripped: + with self.roundtrip( + expected, allow_cleanup_failure=ON_WINDOWS) as roundtripped: unpickled_array = pickle.loads(pickle.dumps(roundtripped['foo'])) self.assertDatasetIdentical(expected['foo'], unpickled_array) @@ -455,13 +458,17 @@ def test_encoding_same_dtype(self): @contextlib.contextmanager -def create_tmp_file(suffix='.nc'): +def create_tmp_file(suffix='.nc', allow_cleanup_failure=False): temp_dir = tempfile.mkdtemp() path = os.path.join(temp_dir, 'temp-%s%s' % (next(_counter), suffix)) try: yield path finally: - shutil.rmtree(temp_dir) + try: + shutil.rmtree(temp_dir) + except OSError: + if not allow_cleanup_failure: + raise class BaseNetCDF4Test(CFEncodedDataTest): @@ -671,8 +678,10 @@ def create_store(self): yield store @contextlib.contextmanager - def roundtrip(self, data, save_kwargs={}, open_kwargs={}): - with create_tmp_file() as tmp_file: + def roundtrip(self, data, save_kwargs={}, open_kwargs={}, + allow_cleanup_failure=False): + with create_tmp_file( + allow_cleanup_failure=allow_cleanup_failure) as tmp_file: data.to_netcdf(tmp_file, **save_kwargs) with open_dataset(tmp_file, **open_kwargs) as ds: yield ds @@ -710,9 +719,11 @@ def test_unsorted_index_raises(self): @requires_dask class NetCDF4ViaDaskDataTest(NetCDF4DataTest): @contextlib.contextmanager - def roundtrip(self, data, save_kwargs={}, open_kwargs={}): + def roundtrip(self, data, save_kwargs={}, open_kwargs={}, + allow_cleanup_failure=False): with NetCDF4DataTest.roundtrip( - self, data, save_kwargs, open_kwargs) as ds: + self, data, save_kwargs, open_kwargs, + allow_cleanup_failure) as ds: yield ds.chunk() def test_unsorted_index_raises(self): @@ -724,12 +735,6 @@ def test_dataset_caching(self): # caching behavior differs for dask pass - @pytest.mark.skipif( - sys.platform == 'win32', - reason='something related to deleting unclosed files, see GH1128') - def test_pickle(self): - super(NetCDF4ViaDaskDataTest, self).test_pickle() - @requires_scipy class ScipyInMemoryDataTest(CFEncodedDataTest, Only32BitTypes, TestCase): @@ -739,7 +744,8 @@ def create_store(self): yield backends.ScipyDataStore(fobj, 'w') @contextlib.contextmanager - def roundtrip(self, data, save_kwargs={}, open_kwargs={}): + def roundtrip(self, data, save_kwargs={}, open_kwargs={}, + allow_cleanup_failure=False): serialized = data.to_netcdf(**save_kwargs) with open_dataset(serialized, engine='scipy', **open_kwargs) as ds: yield ds @@ -762,8 +768,10 @@ def create_store(self): yield store @contextlib.contextmanager - def roundtrip(self, data, save_kwargs={}, open_kwargs={}): - with create_tmp_file() as tmp_file: + def roundtrip(self, data, save_kwargs={}, open_kwargs={}, + allow_cleanup_failure=False): + with create_tmp_file( + allow_cleanup_failure=allow_cleanup_failure) as tmp_file: data.to_netcdf(tmp_file, engine='scipy', **save_kwargs) with open_dataset(tmp_file, engine='scipy', **open_kwargs) as ds: yield ds @@ -801,8 +809,10 @@ def create_store(self): yield store @contextlib.contextmanager - def roundtrip(self, data, save_kwargs={}, open_kwargs={}): - with create_tmp_file() as tmp_file: + def roundtrip(self, data, save_kwargs={}, open_kwargs={}, + allow_cleanup_failure=False): + with create_tmp_file( + allow_cleanup_failure=allow_cleanup_failure) as tmp_file: data.to_netcdf(tmp_file, format='NETCDF3_CLASSIC', engine='netcdf4', **save_kwargs) with open_dataset(tmp_file, engine='netcdf4', **open_kwargs) as ds: @@ -820,8 +830,10 @@ def create_store(self): yield store @contextlib.contextmanager - def roundtrip(self, data, save_kwargs={}, open_kwargs={}): - with create_tmp_file() as tmp_file: + def roundtrip(self, data, save_kwargs={}, open_kwargs={}, + allow_cleanup_failure=False): + with create_tmp_file( + allow_cleanup_failure=allow_cleanup_failure) as tmp_file: data.to_netcdf(tmp_file, format='NETCDF4_CLASSIC', engine='netcdf4', **save_kwargs) with open_dataset(tmp_file, engine='netcdf4', **open_kwargs) as ds: @@ -838,8 +850,10 @@ def test_write_store(self): pass @contextlib.contextmanager - def roundtrip(self, data, save_kwargs={}, open_kwargs={}): - with create_tmp_file() as tmp_file: + def roundtrip(self, data, save_kwargs={}, open_kwargs={}, + allow_cleanup_failure=False): + with create_tmp_file( + allow_cleanup_failure=allow_cleanup_failure) as tmp_file: data.to_netcdf(tmp_file, format='netcdf3_64bit', **save_kwargs) with open_dataset(tmp_file, **open_kwargs) as ds: yield ds @@ -888,8 +902,10 @@ def create_store(self): yield backends.H5NetCDFStore(tmp_file, 'w') @contextlib.contextmanager - def roundtrip(self, data, save_kwargs={}, open_kwargs={}): - with create_tmp_file() as tmp_file: + def roundtrip(self, data, save_kwargs={}, open_kwargs={}, + allow_cleanup_failure=False): + with create_tmp_file( + allow_cleanup_failure=allow_cleanup_failure) as tmp_file: data.to_netcdf(tmp_file, engine='h5netcdf', **save_kwargs) with open_dataset(tmp_file, engine='h5netcdf', **open_kwargs) as ds: yield ds @@ -938,7 +954,8 @@ def create_store(self): yield Dataset() @contextlib.contextmanager - def roundtrip(self, data, save_kwargs={}, open_kwargs={}): + def roundtrip(self, data, save_kwargs={}, open_kwargs={}, + allow_cleanup_failure=False): yield data.chunk() def test_roundtrip_datetime_data(self): @@ -1141,8 +1158,10 @@ def test_orthogonal_indexing(self): pass @contextlib.contextmanager - def roundtrip(self, data, save_kwargs={}, open_kwargs={}): - with create_tmp_file() as tmp_file: + def roundtrip(self, data, save_kwargs={}, open_kwargs={}, + allow_cleanup_failure=False): + with create_tmp_file( + allow_cleanup_failure=allow_cleanup_failure) as tmp_file: data.to_netcdf(tmp_file, engine='scipy', **save_kwargs) with open_dataset(tmp_file, engine='pynio', **open_kwargs) as ds: yield ds diff --git a/xarray/test/test_conventions.py b/xarray/test/test_conventions.py index 9821369e18e..8adb20dced3 100644 --- a/xarray/test/test_conventions.py +++ b/xarray/test/test_conventions.py @@ -626,7 +626,8 @@ def create_store(self): yield CFEncodedInMemoryStore() @contextlib.contextmanager - def roundtrip(self, data, save_kwargs={}, open_kwargs={}): + def roundtrip(self, data, save_kwargs={}, open_kwargs={}, + allow_cleanup_failure=False): store = CFEncodedInMemoryStore() data.dump_to_store(store, **save_kwargs) yield open_dataset(store, **open_kwargs)