Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix to_netcdf append bug (GH1215) #1609

Merged
merged 19 commits into from
Oct 25, 2017
Merged
8 changes: 6 additions & 2 deletions doc/io.rst
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,10 @@ for dealing with datasets too big to fit into memory. Instead, xarray integrates
with dask.array (see :ref:`dask`), which provides a fully featured engine for
streaming computation.

It is possible to append or overwrite netCDF variables using the ``mode='a'``
argument. When using this option, all variables in the dataset will be written
to the original netCDF file, regardless of whether they already exist in the original file.

.. _io.encoding:

Reading encoded data
Expand Down Expand Up @@ -390,7 +394,7 @@ over the network until we look at particular values:

Some servers require authentication before we can access the data. For this
purpose we can explicitly create a :py:class:`~xarray.backends.PydapDataStore`
and pass in a `Requests`__ session object. For example for
and pass in a `Requests`__ session object. For example for
HTTP Basic authentication::

import xarray as xr
Expand All @@ -403,7 +407,7 @@ HTTP Basic authentication::
session=session)
ds = xr.open_dataset(store)

`Pydap's cas module`__ has functions that generate custom sessions for
`Pydap's cas module`__ has functions that generate custom sessions for
servers that use CAS single sign-on. For example, to connect to servers
that require NASA's URS authentication::

Expand Down
4 changes: 4 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,10 @@ Bug fixes
the first argument was a numpy variable (:issue:`1588`).
By `Guido Imperiale <https://github.com/crusaderky>`_.

- Fix bug in :py:meth:`~xarray.Dataset.to_netcdf` when writing in append mode
(:issue:`1215`).
By `Joe Hamman <https://github.com/jhamman>`_.

- Fix ``netCDF4`` backend to properly roundtrip the ``shuffle`` encoding option
(:issue:`1606`).
By `Joe Hamman <https://github.com/jhamman>`_.
Expand Down
8 changes: 6 additions & 2 deletions xarray/backends/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,8 +223,12 @@ def set_variables(self, variables, check_encoding_set,
for vn, v in iteritems(variables):
name = _encode_variable_name(vn)
check = vn in check_encoding_set
target, source = self.prepare_variable(
name, v, check, unlimited_dims=unlimited_dims)
if vn not in self.variables:
target, source = self.prepare_variable(
name, v, check, unlimited_dims=unlimited_dims)
else:
target, source = self.ds.variables[name], v.data

self.writer.add(source, target)

def set_necessary_dimensions(self, variable, unlimited_dims=None):
Expand Down
3 changes: 2 additions & 1 deletion xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -974,7 +974,8 @@ def to_netcdf(self, path=None, mode='w', format=None, group=None,
default format becomes NETCDF3_64BIT).
mode : {'w', 'a'}, optional
Write ('w') or append ('a') mode. If mode='w', any existing file at
this location will be overwritten.
this location will be overwritten. If mode='a', existing variables
will be overwritten.
format : {'NETCDF4', 'NETCDF4_CLASSIC', 'NETCDF3_64BIT', 'NETCDF3_CLASSIC'}, optional
File format for the resulting netCDF file:

Expand Down
158 changes: 76 additions & 82 deletions xarray/tests/test_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,12 +128,42 @@ class Only32BitTypes(object):

class DatasetIOTestCases(object):
autoclose = False
engine = None
file_format = None

def create_store(self):
    # Abstract hook: concrete backend test classes yield a
    # backend-specific data store (e.g. NetCDF4DataStore) here.
    raise NotImplementedError

def roundtrip(self, data, **kwargs):
raise NotImplementedError
@contextlib.contextmanager
def roundtrip(self, data, save_kwargs=None, open_kwargs=None,
              allow_cleanup_failure=False):
    """Write ``data`` to a temporary file and yield it re-opened.

    Parameters
    ----------
    data : Dataset
        Dataset to write with ``self.save`` and read back with
        ``self.open``.
    save_kwargs, open_kwargs : dict, optional
        Extra keyword arguments forwarded to ``self.save`` and
        ``self.open`` respectively.
    allow_cleanup_failure : bool, optional
        Forwarded to ``create_tmp_file``; tolerate failure to delete
        the temporary file on exit.
    """
    # Use None sentinels instead of mutable {} defaults, which would be
    # shared across every call of this method.
    save_kwargs = {} if save_kwargs is None else save_kwargs
    open_kwargs = {} if open_kwargs is None else open_kwargs
    with create_tmp_file(
            allow_cleanup_failure=allow_cleanup_failure) as path:
        self.save(data, path, **save_kwargs)
        with self.open(path, **open_kwargs) as ds:
            yield ds

@contextlib.contextmanager
def roundtrip_append(self, data, save_kwargs=None, open_kwargs=None,
                     allow_cleanup_failure=False):
    """Write ``data`` one variable at a time, then yield it re-opened.

    The first variable is written with ``mode='w'`` (creating the file);
    every subsequent variable is appended with ``mode='a'``.  This
    exercises the append path fixed for GH1215.

    Parameters
    ----------
    data : Dataset
        Dataset whose variables are written individually.
    save_kwargs, open_kwargs : dict, optional
        Extra keyword arguments forwarded to ``self.save`` / ``self.open``.
    allow_cleanup_failure : bool, optional
        Forwarded to ``create_tmp_file``.
    """
    # Use None sentinels instead of mutable {} defaults, which would be
    # shared across every call of this method.
    save_kwargs = {} if save_kwargs is None else save_kwargs
    open_kwargs = {} if open_kwargs is None else open_kwargs
    with create_tmp_file(
            allow_cleanup_failure=allow_cleanup_failure) as path:
        for i, key in enumerate(data.variables):
            mode = 'a' if i > 0 else 'w'
            # data[[key]] selects a single-variable sub-Dataset
            self.save(data[[key]], path, mode=mode, **save_kwargs)
        with self.open(path, **open_kwargs) as ds:
            yield ds

# The save/open methods may be overwritten below
def save(self, dataset, path, **kwargs):
    # Write ``dataset`` to ``path`` using this test class's configured
    # backend ``engine`` and on-disk ``file_format``.
    dataset.to_netcdf(path, engine=self.engine, format=self.file_format,
                      **kwargs)

@contextlib.contextmanager
def open(self, path, **kwargs):
    # Open ``path`` with this class's backend engine and class-level
    # ``autoclose`` setting, yielding the opened dataset.
    with open_dataset(path, engine=self.engine, autoclose=self.autoclose,
                      **kwargs) as ds:
        yield ds

def test_zero_dimensional_variable(self):
expected = create_test_data()
Expand Down Expand Up @@ -563,6 +593,23 @@ def test_encoding_same_dtype(self):
self.assertEqual(actual.x.encoding['dtype'], 'f4')
self.assertEqual(ds.x.encoding, {})

def test_append_write(self):
    # regression for GH1215: appending variables one at a time must
    # round-trip the full dataset
    expected = create_test_data()
    with self.roundtrip_append(expected) as computed:
        assert_allclose(expected, computed)

def test_append_overwrite_values(self):
    # regression for GH1215
    data = create_test_data()
    with create_tmp_file(allow_cleanup_failure=False) as tmp_file:
        # initial write creates the file on disk
        self.save(data, tmp_file, mode='w')
        # mutate an existing variable and add a brand-new one, then
        # append: mode='a' must both overwrite var2 and create var9
        data['var2'][:] = -999
        data['var9'] = data['var2'] * 3
        self.save(data[['var2', 'var9']], tmp_file, mode='a')
        with self.open(tmp_file) as actual:
            assert_allclose(data, actual)


_counter = itertools.count()

Expand Down Expand Up @@ -592,6 +639,9 @@ def create_tmp_files(nfiles, suffix='.nc', allow_cleanup_failure=False):

@requires_netCDF4
class BaseNetCDF4Test(CFEncodedDataTest):

engine = 'netcdf4'

def test_open_group(self):
# Create a netCDF file with a dataset stored within a group
with create_tmp_file() as tmp_file:
Expand Down Expand Up @@ -813,16 +863,6 @@ def create_store(self):
with backends.NetCDF4DataStore.open(tmp_file, mode='w') as store:
yield store

@contextlib.contextmanager
def roundtrip(self, data, save_kwargs={}, open_kwargs={},
allow_cleanup_failure=False):
with create_tmp_file(
allow_cleanup_failure=allow_cleanup_failure) as tmp_file:
data.to_netcdf(tmp_file, **save_kwargs)
with open_dataset(tmp_file,
autoclose=self.autoclose, **open_kwargs) as ds:
yield ds

def test_variable_order(self):
# doesn't work with scipy or h5py :(
ds = Dataset()
Expand Down Expand Up @@ -883,19 +923,13 @@ class NetCDF4ViaDaskDataTestAutocloseTrue(NetCDF4ViaDaskDataTest):

@requires_scipy
class ScipyInMemoryDataTest(CFEncodedDataTest, Only32BitTypes, TestCase):
engine = 'scipy'

@contextlib.contextmanager
def create_store(self):
fobj = BytesIO()
yield backends.ScipyDataStore(fobj, 'w')

@contextlib.contextmanager
def roundtrip(self, data, save_kwargs={}, open_kwargs={},
allow_cleanup_failure=False):
serialized = data.to_netcdf(**save_kwargs)
with open_dataset(serialized, engine='scipy',
autoclose=self.autoclose, **open_kwargs) as ds:
yield ds

def test_to_netcdf_explicit_engine(self):
# regression test for GH1321
Dataset({'foo': 42}).to_netcdf(engine='scipy')
Expand All @@ -915,6 +949,8 @@ class ScipyInMemoryDataTestAutocloseTrue(ScipyInMemoryDataTest):

@requires_scipy
class ScipyFileObjectTest(CFEncodedDataTest, Only32BitTypes, TestCase):
engine = 'scipy'

@contextlib.contextmanager
def create_store(self):
fobj = BytesIO()
Expand All @@ -925,9 +961,9 @@ def roundtrip(self, data, save_kwargs={}, open_kwargs={},
allow_cleanup_failure=False):
with create_tmp_file() as tmp_file:
with open(tmp_file, 'wb') as f:
data.to_netcdf(f, **save_kwargs)
self.save(data, f, **save_kwargs)
with open(tmp_file, 'rb') as f:
with open_dataset(f, engine='scipy', **open_kwargs) as ds:
with self.open(f, **open_kwargs) as ds:
yield ds

@pytest.mark.skip(reason='cannot pickle file objects')
Expand All @@ -941,22 +977,14 @@ def test_pickle_dataarray(self):

@requires_scipy
class ScipyFilePathTest(CFEncodedDataTest, Only32BitTypes, TestCase):
engine = 'scipy'

@contextlib.contextmanager
def create_store(self):
with create_tmp_file() as tmp_file:
with backends.ScipyDataStore(tmp_file, mode='w') as store:
yield store

@contextlib.contextmanager
def roundtrip(self, data, save_kwargs={}, open_kwargs={},
allow_cleanup_failure=False):
with create_tmp_file(
allow_cleanup_failure=allow_cleanup_failure) as tmp_file:
data.to_netcdf(tmp_file, engine='scipy', **save_kwargs)
with open_dataset(tmp_file, engine='scipy',
autoclose=self.autoclose, **open_kwargs) as ds:
yield ds

def test_array_attrs(self):
ds = Dataset(attrs={'foo': [[1, 2], [3, 4]]})
with self.assertRaisesRegexp(ValueError, 'must be 1-dimensional'):
Expand Down Expand Up @@ -995,24 +1023,16 @@ class ScipyFilePathTestAutocloseTrue(ScipyFilePathTest):

@requires_netCDF4
class NetCDF3ViaNetCDF4DataTest(CFEncodedDataTest, Only32BitTypes, TestCase):
engine = 'netcdf4'
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This isn't quite enough -- you also need to set the format in this case. Take a look at roundtrip() below.

Probably the cleanest fix would be refactor roundtrip() into three methods:

     # on the base class
     @contextlib.contextmanager
     def roundtrip(self, data, save_kwargs={}, open_kwargs={},
                   allow_cleanup_failure=False):
         with create_tmp_file(
                 allow_cleanup_failure=allow_cleanup_failure) as path:
             self.save(data, path, **save_kwargs)
             with self.open(path, **open_kwargs) as ds:
                 yield ds

     # on subclasses, e.g., for NetCDF3ViaNetCDF4DataTest
     # on subclasses, e.g., for NetCDF3ViaNetCDF4DataTest
     def save(self, dataset, path, **kwargs):
         dataset.to_netcdf(path, format='NETCDF3_CLASSIC',
                           engine='netcdf4', **kwargs)

     @contextlib.contextmanager
     def open(self, path, **kwargs):
          with open_dataset(path, engine='netcdf4',
                            autoclose=self.autoclose, **kwargs) as ds:
               yield ds

Then you could write roundtrip_append() in terms of save and open.

file_format = 'NETCDF3_CLASSIC'

@contextlib.contextmanager
def create_store(self):
with create_tmp_file() as tmp_file:
with backends.NetCDF4DataStore.open(
tmp_file, mode='w', format='NETCDF3_CLASSIC') as store:
yield store

@contextlib.contextmanager
def roundtrip(self, data, save_kwargs={}, open_kwargs={},
allow_cleanup_failure=False):
with create_tmp_file(
allow_cleanup_failure=allow_cleanup_failure) as tmp_file:
data.to_netcdf(tmp_file, format='NETCDF3_CLASSIC',
engine='netcdf4', **save_kwargs)
with open_dataset(tmp_file, engine='netcdf4',
autoclose=self.autoclose, **open_kwargs) as ds:
yield ds


class NetCDF3ViaNetCDF4DataTestAutocloseTrue(NetCDF3ViaNetCDF4DataTest):
autoclose = True
Expand All @@ -1021,24 +1041,16 @@ class NetCDF3ViaNetCDF4DataTestAutocloseTrue(NetCDF3ViaNetCDF4DataTest):
@requires_netCDF4
class NetCDF4ClassicViaNetCDF4DataTest(CFEncodedDataTest, Only32BitTypes,
TestCase):
engine = 'netcdf4'
file_format = 'NETCDF4_CLASSIC'

@contextlib.contextmanager
def create_store(self):
with create_tmp_file() as tmp_file:
with backends.NetCDF4DataStore.open(
tmp_file, mode='w', format='NETCDF4_CLASSIC') as store:
yield store

@contextlib.contextmanager
def roundtrip(self, data, save_kwargs={}, open_kwargs={},
allow_cleanup_failure=False):
with create_tmp_file(
allow_cleanup_failure=allow_cleanup_failure) as tmp_file:
data.to_netcdf(tmp_file, format='NETCDF4_CLASSIC',
engine='netcdf4', **save_kwargs)
with open_dataset(tmp_file, engine='netcdf4',
autoclose=self.autoclose, **open_kwargs) as ds:
yield ds


class NetCDF4ClassicViaNetCDF4DataTestAutocloseTrue(
NetCDF4ClassicViaNetCDF4DataTest):
Expand All @@ -1049,21 +1061,12 @@ class NetCDF4ClassicViaNetCDF4DataTestAutocloseTrue(
class GenericNetCDFDataTest(CFEncodedDataTest, Only32BitTypes, TestCase):
# verify that we can read and write netCDF3 files as long as we have scipy
# or netCDF4-python installed
file_format = 'netcdf3_64bit'

def test_write_store(self):
# there's no specific store to test here
pass

@contextlib.contextmanager
def roundtrip(self, data, save_kwargs={}, open_kwargs={},
allow_cleanup_failure=False):
with create_tmp_file(
allow_cleanup_failure=allow_cleanup_failure) as tmp_file:
data.to_netcdf(tmp_file, format='netcdf3_64bit', **save_kwargs)
with open_dataset(tmp_file,
autoclose=self.autoclose, **open_kwargs) as ds:
yield ds

def test_engine(self):
data = create_test_data()
with self.assertRaisesRegexp(ValueError, 'unrecognized engine'):
Expand Down Expand Up @@ -1122,21 +1125,13 @@ class GenericNetCDFDataTestAutocloseTrue(GenericNetCDFDataTest):
@requires_h5netcdf
@requires_netCDF4
class H5NetCDFDataTest(BaseNetCDF4Test, TestCase):
engine = 'h5netcdf'

@contextlib.contextmanager
def create_store(self):
with create_tmp_file() as tmp_file:
yield backends.H5NetCDFStore(tmp_file, 'w')

@contextlib.contextmanager
def roundtrip(self, data, save_kwargs={}, open_kwargs={},
allow_cleanup_failure=False):
with create_tmp_file(
allow_cleanup_failure=allow_cleanup_failure) as tmp_file:
data.to_netcdf(tmp_file, engine='h5netcdf', **save_kwargs)
with open_dataset(tmp_file, engine='h5netcdf',
autoclose=self.autoclose, **open_kwargs) as ds:
yield ds

def test_orthogonal_indexing(self):
# doesn't work for h5py (without using dask as an intermediate layer)
pass
Expand Down Expand Up @@ -1646,14 +1641,13 @@ def test_orthogonal_indexing(self):
pass

@contextlib.contextmanager
def roundtrip(self, data, save_kwargs={}, open_kwargs={},
allow_cleanup_failure=False):
with create_tmp_file(
allow_cleanup_failure=allow_cleanup_failure) as tmp_file:
data.to_netcdf(tmp_file, engine='scipy', **save_kwargs)
with open_dataset(tmp_file, engine='pynio',
autoclose=self.autoclose, **open_kwargs) as ds:
yield ds
def open(self, path, **kwargs):
with open_dataset(path, engine='pynio', autoclose=self.autoclose,
**kwargs) as ds:
yield ds

def save(self, dataset, path, **kwargs):
dataset.to_netcdf(path, engine='scipy', **kwargs)

def test_weakrefs(self):
example = Dataset({'foo': ('x', np.arange(5.0))})
Expand Down