diff --git a/doc/dask.rst b/doc/dask.rst index adf0a6bf585..19cbc11292c 100644 --- a/doc/dask.rst +++ b/doc/dask.rst @@ -75,13 +75,14 @@ entirely equivalent to opening a dataset using ``open_dataset`` and then chunking the data using the ``chunk`` method, e.g., ``xr.open_dataset('example-data.nc').chunk({'time': 10})``. -To open multiple files simultaneously, use :py:func:`~xarray.open_mfdataset`:: +To open multiple files simultaneously in parallel using Dask delayed, +use :py:func:`~xarray.open_mfdataset`:: - xr.open_mfdataset('my/files/*.nc') + xr.open_mfdataset('my/files/*.nc', parallel=True) This function will automatically concatenate and merge dataset into one in the simple cases that it understands (see :py:func:`~xarray.auto_combine` -for the full disclaimer). By default, ``open_mfdataset`` will chunk each +for the full disclaimer). By default, :py:func:`~xarray.open_mfdataset` will chunk each netCDF file into a single Dask array; again, supply the ``chunks`` argument to control the size of the resulting Dask arrays. In more complex cases, you can open each file individually using ``open_dataset`` and merge the result, as diff --git a/doc/io.rst b/doc/io.rst index f7ac8c095b9..775d915188e 100644 --- a/doc/io.rst +++ b/doc/io.rst @@ -99,7 +99,9 @@ netCDF The recommended way to store xarray data structures is `netCDF`__, which is a binary file format for self-described datasets that originated in the geosciences. xarray is based on the netCDF data model, so netCDF files -on disk directly correspond to :py:class:`~xarray.Dataset` objects. +on disk directly correspond to :py:class:`~xarray.Dataset` objects (more accurately, +a group in a netCDF file directly corresponds to a to :py:class:`~xarray.Dataset` object. +See :ref:`io.netcdf_groups` for more.) NetCDF is supported on almost all platforms, and parsers exist for the vast majority of scientific programming languages. Recent versions of @@ -121,7 +123,7 @@ read/write netCDF V4 files and use the compression options described below). __ https://github.com/Unidata/netcdf4-python We can save a Dataset to disk using the -:py:attr:`Dataset.to_netcdf ` method: +:py:meth:`~Dataset.to_netcdf` method: .. ipython:: python @@ -147,19 +149,6 @@ convert the ``DataArray`` to a ``Dataset`` before saving, and then convert back when loading, ensuring that the ``DataArray`` that is loaded is always exactly the same as the one that was saved. -NetCDF groups are not supported as part of the -:py:class:`~xarray.Dataset` data model. Instead, groups can be loaded -individually as Dataset objects. -To do so, pass a ``group`` keyword argument to the -``open_dataset`` function. The group can be specified as a path-like -string, e.g., to access subgroup 'bar' within group 'foo' pass -'/foo/bar' as the ``group`` argument. -In a similar way, the ``group`` keyword argument can be given to the -:py:meth:`~xarray.Dataset.to_netcdf` method to write to a group -in a netCDF file. -When writing multiple groups in one file, pass ``mode='a'`` to ``to_netcdf`` -to ensure that each call does not delete the file. - Data is always loaded lazily from netCDF files. You can manipulate, slice and subset Dataset and DataArray objects, and no array values are loaded into memory until you try to perform some sort of actual computation. For an example of how these @@ -195,6 +184,24 @@ It is possible to append or overwrite netCDF variables using the ``mode='a'`` argument. When using this option, all variables in the dataset will be written to the original netCDF file, regardless if they exist in the original dataset. + +.. _io.netcdf_groups: + +Groups +~~~~~~ + +NetCDF groups are not supported as part of the :py:class:`~xarray.Dataset` data model. +Instead, groups can be loaded individually as Dataset objects. +To do so, pass a ``group`` keyword argument to the +:py:func:`~xarray.open_dataset` function. The group can be specified as a path-like +string, e.g., to access subgroup ``'bar'`` within group ``'foo'`` pass +``'/foo/bar'`` as the ``group`` argument. +In a similar way, the ``group`` keyword argument can be given to the +:py:meth:`~xarray.Dataset.to_netcdf` method to write to a group +in a netCDF file. +When writing multiple groups in one file, pass ``mode='a'`` to +:py:meth:`~xarray.Dataset.to_netcdf` to ensure that each call does not delete the file. + .. _io.encoding: Reading encoded data @@ -203,7 +210,7 @@ Reading encoded data NetCDF files follow some conventions for encoding datetime arrays (as numbers with a "units" attribute) and for packing and unpacking data (as described by the "scale_factor" and "add_offset" attributes). If the argument -``decode_cf=True`` (default) is given to ``open_dataset``, xarray will attempt +``decode_cf=True`` (default) is given to :py:func:`~xarray.open_dataset`, xarray will attempt to automatically decode the values in the netCDF objects according to `CF conventions`_. Sometimes this will fail, for example, if a variable has an invalid "units" or "calendar" attribute. For these cases, you can @@ -247,6 +254,130 @@ will remove encoding information. import os os.remove('saved_on_disk.nc') + +.. _combining multiple files: + +Reading multi-file datasets +........................... + +NetCDF files are often encountered in collections, e.g., with different files +corresponding to different model runs or one file per timestamp. +xarray can straightforwardly combine such files into a single Dataset by making use of +:py:func:`~xarray.concat`, :py:func:`~xarray.merge`, :py:func:`~xarray.combine_nested` and +:py:func:`~xarray.combine_by_coords`. For details on the difference between these +functions see :ref:`combining data`. + +Xarray includes support for manipulating datasets that don't fit into memory +with dask_. If you have dask installed, you can open multiple files +simultaneously in parallel using :py:func:`~xarray.open_mfdataset`:: + + xr.open_mfdataset('my/files/*.nc', parallel=True) + +This function automatically concatenates and merges multiple files into a +single xarray dataset. +It is the recommended way to open multiple files with xarray. +For more details on parallel reading, see :ref:`combining.multi`, :ref:`dask.io` and a +`blog post`_ by Stephan Hoyer. +:py:func:`~xarray.open_mfdataset` takes many kwargs that allow you to +control its behaviour (for e.g. ``parallel``, ``combine``, ``compat``, ``join``, ``concat_dim``). +See its docstring for more details. + + +.. note:: + + A common use-case involves a dataset distributed across a large number of files with + each file containing a large number of variables. Commonly a few of these variables + need to be concatenated along a dimension (say ``"time"``), while the rest are equal + across the datasets (ignoring floating point differences). The following command + with suitable modifications (such as ``parallel=True``) works well with such datasets:: + + xr.open_mfdataset('my/files/*.nc', concat_dim="time", + data_vars='minimal', coords='minimal', compat='override') + + This command concatenates variables along the ``"time"`` dimension, but only those that + already contain the ``"time"`` dimension (``data_vars='minimal', coords='minimal'``). + Variables that lack the ``"time"`` dimension are taken from the first dataset + (``compat='override'``). + + +.. _dask: http://dask.pydata.org +.. _blog post: http://stephanhoyer.com/2015/06/11/xray-dask-out-of-core-labeled-arrays/ + +Sometimes multi-file datasets are not conveniently organized for easy use of :py:func:`~xarray.open_mfdataset`. +One can use the ``preprocess`` argument to provide a function that takes a dataset +and returns a modified Dataset. +:py:func:`~xarray.open_mfdataset` will call ``preprocess`` on every dataset +(corresponding to each file) prior to combining them. + + +If :py:func:`~xarray.open_mfdataset` does not meet your needs, other approaches are possible. +The general pattern for parallel reading of multiple files +using dask, modifying those datasets and then combining into a single ``Dataset`` is:: + + def modify(ds): + # modify ds here + return ds + + + # this is basically what open_mfdataset does + open_kwargs = dict(decode_cf=True, decode_times=False) + open_tasks = [dask.delayed(xr.open_dataset)(f, **open_kwargs) for f in file_names] + tasks = [dask.delayed(modify)(task) for task in open_tasks] + datasets = dask.compute(tasks) # get a list of xarray.Datasets + combined = xr.combine_nested(datasets) # or some combination of concat, merge + + +As an example, here's how we could approximate ``MFDataset`` from the netCDF4 +library:: + + from glob import glob + import xarray as xr + + def read_netcdfs(files, dim): + # glob expands paths with * to a list of files, like the unix shell + paths = sorted(glob(files)) + datasets = [xr.open_dataset(p) for p in paths] + combined = xr.concat(dataset, dim) + return combined + + combined = read_netcdfs('/all/my/files/*.nc', dim='time') + +This function will work in many cases, but it's not very robust. First, it +never closes files, which means it will fail one you need to load more than +a few thousands file. Second, it assumes that you want all the data from each +file and that it can all fit into memory. In many situations, you only need +a small subset or an aggregated summary of the data from each file. + +Here's a slightly more sophisticated example of how to remedy these +deficiencies:: + + def read_netcdfs(files, dim, transform_func=None): + def process_one_path(path): + # use a context manager, to ensure the file gets closed after use + with xr.open_dataset(path) as ds: + # transform_func should do some sort of selection or + # aggregation + if transform_func is not None: + ds = transform_func(ds) + # load all data from the transformed dataset, to ensure we can + # use it after closing each original file + ds.load() + return ds + + paths = sorted(glob(files)) + datasets = [process_one_path(p) for p in paths] + combined = xr.concat(datasets, dim) + return combined + + # here we suppose we only care about the combined mean of each file; + # you might also use indexing operations like .sel to subset datasets + combined = read_netcdfs('/all/my/files/*.nc', dim='time', + transform_func=lambda ds: ds.mean()) + +This pattern works well and is very robust. We've used similar code to process +tens of thousands of files constituting 100s of GB of data. + + .. _io.netcdf.writing_encoded: Writing encoded data @@ -817,84 +948,3 @@ For CSV files, one might also consider `xarray_extras`_. .. _xarray_extras: https://xarray-extras.readthedocs.io/en/latest/api/csv.html .. _IO tools: http://pandas.pydata.org/pandas-docs/stable/io.html - - -.. _combining multiple files: - - -Combining multiple files ------------------------- - -NetCDF files are often encountered in collections, e.g., with different files -corresponding to different model runs. xarray can straightforwardly combine such -files into a single Dataset by making use of :py:func:`~xarray.concat`, -:py:func:`~xarray.merge`, :py:func:`~xarray.combine_nested` and -:py:func:`~xarray.combine_by_coords`. For details on the difference between these -functions see :ref:`combining data`. - -.. note:: - - Xarray includes support for manipulating datasets that don't fit into memory - with dask_. If you have dask installed, you can open multiple files - simultaneously using :py:func:`~xarray.open_mfdataset`:: - - xr.open_mfdataset('my/files/*.nc') - - This function automatically concatenates and merges multiple files into a - single xarray dataset. - It is the recommended way to open multiple files with xarray. - For more details, see :ref:`combining.multi`, :ref:`dask.io` and a - `blog post`_ by Stephan Hoyer. - -.. _dask: http://dask.pydata.org -.. _blog post: http://stephanhoyer.com/2015/06/11/xray-dask-out-of-core-labeled-arrays/ - -For example, here's how we could approximate ``MFDataset`` from the netCDF4 -library:: - - from glob import glob - import xarray as xr - - def read_netcdfs(files, dim): - # glob expands paths with * to a list of files, like the unix shell - paths = sorted(glob(files)) - datasets = [xr.open_dataset(p) for p in paths] - combined = xr.concat(dataset, dim) - return combined - - combined = read_netcdfs('/all/my/files/*.nc', dim='time') - -This function will work in many cases, but it's not very robust. First, it -never closes files, which means it will fail one you need to load more than -a few thousands file. Second, it assumes that you want all the data from each -file and that it can all fit into memory. In many situations, you only need -a small subset or an aggregated summary of the data from each file. - -Here's a slightly more sophisticated example of how to remedy these -deficiencies:: - - def read_netcdfs(files, dim, transform_func=None): - def process_one_path(path): - # use a context manager, to ensure the file gets closed after use - with xr.open_dataset(path) as ds: - # transform_func should do some sort of selection or - # aggregation - if transform_func is not None: - ds = transform_func(ds) - # load all data from the transformed dataset, to ensure we can - # use it after closing each original file - ds.load() - return ds - - paths = sorted(glob(files)) - datasets = [process_one_path(p) for p in paths] - combined = xr.concat(datasets, dim) - return combined - - # here we suppose we only care about the combined mean of each file; - # you might also use indexing operations like .sel to subset datasets - combined = read_netcdfs('/all/my/files/*.nc', dim='time', - transform_func=lambda ds: ds.mean()) - -This pattern works well and is very robust. We've used similar code to process -tens of thousands of files constituting 100s of GB of data. diff --git a/doc/whats-new.rst b/doc/whats-new.rst index ab4b17ff16d..492c9279e6b 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -93,7 +93,7 @@ New functions/methods By `Deepak Cherian `_ and `David Mertz `_. -- Dataset plotting API for visualizing dependencies between two `DataArray`s! +- Dataset plotting API for visualizing dependencies between two DataArrays! Currently only :py:meth:`Dataset.plot.scatter` is implemented. By `Yohai Bar Sinai `_ and `Deepak Cherian `_ @@ -103,11 +103,30 @@ New functions/methods Enhancements ~~~~~~~~~~~~ -- Added ``join='override'``. This only checks that index sizes are equal among objects and skips - checking indexes for equality. By `Deepak Cherian `_. +- Multiple enhancements to :py:func:`~xarray.concat` and :py:func:`~xarray.open_mfdataset`. -- :py:func:`~xarray.concat` and :py:func:`~xarray.open_mfdataset` now support the ``join`` kwarg. - It is passed down to :py:func:`~xarray.align`. By `Deepak Cherian `_. + - Added ``compat='override'``. When merging, this option picks the variable from the first dataset + and skips all comparisons. + + - Added ``join='override'``. When aligning, this only checks that index sizes are equal among objects + and skips checking indexes for equality. + + - :py:func:`~xarray.concat` and :py:func:`~xarray.open_mfdataset` now support the ``join`` kwarg. + It is passed down to :py:func:`~xarray.align`. + + - :py:func:`~xarray.concat` now calls :py:func:`~xarray.merge` on variables that are not concatenated + (i.e. variables without ``concat_dim`` when ``data_vars`` or ``coords`` are ``"minimal"``). + :py:func:`~xarray.concat` passes its new ``compat`` kwarg down to :py:func:`~xarray.merge`. + (:issue:`2064`) + + Users can avoid a common bottleneck when using :py:func:`~xarray.open_mfdataset` on a large number of + files with variables that are known to be aligned and some of which need not be concatenated. + Slow equality comparisons can now be avoided, for e.g.:: + + data = xr.open_mfdataset(files, concat_dim='time', data_vars='minimal', + coords='minimal', compat='override', join='override') + + By `Deepak Cherian `_: - In :py:meth:`~xarray.Dataset.to_zarr`, passing ``mode`` is not mandatory if ``append_dim`` is set, as it will automatically be set to ``'a'`` internally. diff --git a/xarray/backends/api.py b/xarray/backends/api.py index a20d3c2a306..1f0869cfc53 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -761,7 +761,7 @@ def open_mfdataset( `xarray.auto_combine` is used, but in the future this behavior will switch to use `xarray.combine_by_coords` by default. compat : {'identical', 'equals', 'broadcast_equals', - 'no_conflicts'}, optional + 'no_conflicts', 'override'}, optional String indicating how to compare variables of the same name for potential conflicts when merging: * 'broadcast_equals': all values must be equal when variables are @@ -772,6 +772,7 @@ def open_mfdataset( * 'no_conflicts': only values which are not null in both datasets must be equal. The returned dataset then contains the combination of all non-null values. + * 'override': skip comparing and pick variable from first dataset preprocess : callable, optional If provided, call this function on each dataset prior to concatenation. You can find the file-name from which each dataset was loaded in diff --git a/xarray/core/combine.py b/xarray/core/combine.py index c24be88b19e..e35bb51e030 100644 --- a/xarray/core/combine.py +++ b/xarray/core/combine.py @@ -243,6 +243,7 @@ def _combine_1d( dim=concat_dim, data_vars=data_vars, coords=coords, + compat=compat, fill_value=fill_value, join=join, ) @@ -351,7 +352,7 @@ def combine_nested( Must be the same length as the depth of the list passed to ``datasets``. compat : {'identical', 'equals', 'broadcast_equals', - 'no_conflicts'}, optional + 'no_conflicts', 'override'}, optional String indicating how to compare variables of the same name for potential merge conflicts: @@ -363,6 +364,7 @@ def combine_nested( - 'no_conflicts': only values which are not null in both datasets must be equal. The returned dataset then contains the combination of all non-null values. + - 'override': skip comparing and pick variable from first dataset data_vars : {'minimal', 'different', 'all' or list of str}, optional Details are in the documentation of concat coords : {'minimal', 'different', 'all' or list of str}, optional @@ -504,7 +506,7 @@ def combine_by_coords( datasets : sequence of xarray.Dataset Dataset objects to combine. compat : {'identical', 'equals', 'broadcast_equals', - 'no_conflicts'}, optional + 'no_conflicts', 'override'}, optional String indicating how to compare variables of the same name for potential conflicts: @@ -516,6 +518,7 @@ def combine_by_coords( - 'no_conflicts': only values which are not null in both datasets must be equal. The returned dataset then contains the combination of all non-null values. + - 'override': skip comparing and pick variable from first dataset data_vars : {'minimal', 'different', 'all' or list of str}, optional Details are in the documentation of concat coords : {'minimal', 'different', 'all' or list of str}, optional @@ -598,6 +601,7 @@ def combine_by_coords( concat_dims=concat_dims, data_vars=data_vars, coords=coords, + compat=compat, fill_value=fill_value, join=join, ) @@ -667,7 +671,7 @@ def auto_combine( component files. Set ``concat_dim=None`` explicitly to disable concatenation. compat : {'identical', 'equals', 'broadcast_equals', - 'no_conflicts'}, optional + 'no_conflicts', 'override'}, optional String indicating how to compare variables of the same name for potential conflicts: - 'broadcast_equals': all values must be equal when variables are @@ -678,6 +682,7 @@ def auto_combine( - 'no_conflicts': only values which are not null in both datasets must be equal. The returned dataset then contains the combination of all non-null values. + - 'override': skip comparing and pick variable from first dataset data_vars : {'minimal', 'different', 'all' or list of str}, optional Details are in the documentation of concat coords : {'minimal', 'different', 'all' o list of str}, optional @@ -832,6 +837,7 @@ def _old_auto_combine( dim=dim, data_vars=data_vars, coords=coords, + compat=compat, fill_value=fill_value, join=join, ) @@ -850,6 +856,7 @@ def _auto_concat( coords="different", fill_value=dtypes.NA, join="outer", + compat="no_conflicts", ): if len(datasets) == 1 and dim is None: # There is nothing more to combine, so kick out early. @@ -876,5 +883,10 @@ def _auto_concat( ) dim, = concat_dims return concat( - datasets, dim=dim, data_vars=data_vars, coords=coords, fill_value=fill_value + datasets, + dim=dim, + data_vars=data_vars, + coords=coords, + fill_value=fill_value, + compat=compat, ) diff --git a/xarray/core/concat.py b/xarray/core/concat.py index d5dfa49a8d5..e68c247d880 100644 --- a/xarray/core/concat.py +++ b/xarray/core/concat.py @@ -4,6 +4,7 @@ from . import dtypes, utils from .alignment import align +from .merge import unique_variable, _VALID_COMPAT from .variable import IndexVariable, Variable, as_variable from .variable import concat as concat_vars @@ -59,12 +60,19 @@ def concat( those corresponding to other dimensions. * list of str: The listed coordinate variables will be concatenated, in addition to the 'minimal' coordinates. - compat : {'equals', 'identical'}, optional - String indicating how to compare non-concatenated variables and - dataset global attributes for potential conflicts. 'equals' means - that all variable values and dimensions must be the same; - 'identical' means that variable attributes and global attributes - must also be equal. + compat : {'identical', 'equals', 'broadcast_equals', 'no_conflicts', 'override'}, optional + String indicating how to compare non-concatenated variables of the same name for + potential conflicts. This is passed down to merge. + + - 'broadcast_equals': all values must be equal when variables are + broadcast against each other to ensure common dimensions. + - 'equals': all values and dimensions must be the same. + - 'identical': all values, dimensions and attributes must be the + same. + - 'no_conflicts': only values which are not null in both datasets + must be equal. The returned dataset then contains the combination + of all non-null values. + - 'override': skip comparing and pick variable from first dataset positions : None or list of integer arrays, optional List of integer arrays which specifies the integer positions to which to assign each dataset along the concatenated dimension. If not @@ -107,6 +115,12 @@ def concat( except StopIteration: raise ValueError("must supply at least one object to concatenate") + if compat not in _VALID_COMPAT: + raise ValueError( + "compat=%r invalid: must be 'broadcast_equals', 'equals', 'identical', 'no_conflicts' or 'override'" + % compat + ) + if isinstance(first_obj, DataArray): f = _dataarray_concat elif isinstance(first_obj, Dataset): @@ -143,23 +157,39 @@ def _calc_concat_dim_coord(dim): return dim, coord -def _calc_concat_over(datasets, dim, data_vars, coords): +def _calc_concat_over(datasets, dim, dim_names, data_vars, coords, compat): """ Determine which dataset variables need to be concatenated in the result, - and which can simply be taken from the first dataset. """ # Return values concat_over = set() equals = {} - if dim in datasets[0]: + if dim in dim_names: + concat_over_existing_dim = True concat_over.add(dim) + else: + concat_over_existing_dim = False + + concat_dim_lengths = [] for ds in datasets: + if concat_over_existing_dim: + if dim not in ds.dims: + if dim in ds: + ds = ds.set_coords(dim) + else: + raise ValueError("%r is not present in all datasets" % dim) concat_over.update(k for k, v in ds.variables.items() if dim in v.dims) + concat_dim_lengths.append(ds.dims.get(dim, 1)) def process_subset_opt(opt, subset): if isinstance(opt, str): if opt == "different": + if compat == "override": + raise ValueError( + "Cannot specify both %s='different' and compat='override'." + % subset + ) # all nonindexes that are not the same in each dataset for k in getattr(datasets[0], subset): if k not in concat_over: @@ -173,7 +203,7 @@ def process_subset_opt(opt, subset): for ds_rhs in datasets[1:]: v_rhs = ds_rhs.variables[k].compute() computed.append(v_rhs) - if not v_lhs.equals(v_rhs): + if not getattr(v_lhs, compat)(v_rhs): concat_over.add(k) equals[k] = False # computed variables are not to be re-computed @@ -209,7 +239,29 @@ def process_subset_opt(opt, subset): process_subset_opt(data_vars, "data_vars") process_subset_opt(coords, "coords") - return concat_over, equals + return concat_over, equals, concat_dim_lengths + + +# determine dimensional coordinate names and a dict mapping name to DataArray +def _parse_datasets(datasets): + + dims = set() + all_coord_names = set() + data_vars = set() # list of data_vars + dim_coords = dict() # maps dim name to variable + dims_sizes = {} # shared dimension sizes to expand variables + + for ds in datasets: + dims_sizes.update(ds.dims) + all_coord_names.update(ds.coords) + data_vars.update(ds.data_vars) + + for dim in set(ds.dims) - dims: + if dim not in dim_coords: + dim_coords[dim] = ds.coords[dim].variable + dims = dims | set(ds.dims) + + return dim_coords, dims_sizes, all_coord_names, data_vars def _dataset_concat( @@ -227,11 +279,6 @@ def _dataset_concat( """ from .dataset import Dataset - if compat not in ["equals", "identical"]: - raise ValueError( - "compat=%r invalid: must be 'equals' " "or 'identical'" % compat - ) - dim, coord = _calc_concat_dim_coord(dim) # Make sure we're working on a copy (we'll be loading variables) datasets = [ds.copy() for ds in datasets] @@ -239,62 +286,65 @@ def _dataset_concat( *datasets, join=join, copy=False, exclude=[dim], fill_value=fill_value ) - concat_over, equals = _calc_concat_over(datasets, dim, data_vars, coords) + dim_coords, dims_sizes, coord_names, data_names = _parse_datasets(datasets) + dim_names = set(dim_coords) + unlabeled_dims = dim_names - coord_names + + both_data_and_coords = coord_names & data_names + if both_data_and_coords: + raise ValueError( + "%r is a coordinate in some datasets but not others." % both_data_and_coords + ) + # we don't want the concat dimension in the result dataset yet + dim_coords.pop(dim, None) + dims_sizes.pop(dim, None) + + # case where concat dimension is a coordinate or data_var but not a dimension + if (dim in coord_names or dim in data_names) and dim not in dim_names: + datasets = [ds.expand_dims(dim) for ds in datasets] + + # determine which variables to concatentate + concat_over, equals, concat_dim_lengths = _calc_concat_over( + datasets, dim, dim_names, data_vars, coords, compat + ) + + # determine which variables to merge, and then merge them according to compat + variables_to_merge = (coord_names | data_names) - concat_over - dim_names + + result_vars = {} + if variables_to_merge: + to_merge = {var: [] for var in variables_to_merge} + + for ds in datasets: + absent_merge_vars = variables_to_merge - set(ds.variables) + if absent_merge_vars: + raise ValueError( + "variables %r are present in some datasets but not others. " + % absent_merge_vars + ) - def insert_result_variable(k, v): - assert isinstance(v, Variable) - if k in datasets[0].coords: - result_coord_names.add(k) - result_vars[k] = v + for var in variables_to_merge: + to_merge[var].append(ds.variables[var]) - # create the new dataset and add constant variables - result_vars = OrderedDict() - result_coord_names = set(datasets[0].coords) + for var in variables_to_merge: + result_vars[var] = unique_variable( + var, to_merge[var], compat=compat, equals=equals.get(var, None) + ) + else: + result_vars = OrderedDict() + result_vars.update(dim_coords) + + # assign attrs and encoding from first dataset result_attrs = datasets[0].attrs result_encoding = datasets[0].encoding - for k, v in datasets[0].variables.items(): - if k not in concat_over: - insert_result_variable(k, v) - - # check that global attributes and non-concatenated variables are fixed - # across all datasets + # check that global attributes are fixed across all datasets if necessary for ds in datasets[1:]: if compat == "identical" and not utils.dict_equiv(ds.attrs, result_attrs): - raise ValueError("dataset global attributes not equal") - for k, v in ds.variables.items(): - if k not in result_vars and k not in concat_over: - raise ValueError("encountered unexpected variable %r" % k) - elif (k in result_coord_names) != (k in ds.coords): - raise ValueError( - "%r is a coordinate in some datasets but not " "others" % k - ) - elif k in result_vars and k != dim: - # Don't use Variable.identical as it internally invokes - # Variable.equals, and we may already know the answer - if compat == "identical" and not utils.dict_equiv( - v.attrs, result_vars[k].attrs - ): - raise ValueError("variable %s not identical across datasets" % k) - - # Proceed with equals() - try: - # May be populated when using the "different" method - is_equal = equals[k] - except KeyError: - result_vars[k].load() - is_equal = v.equals(result_vars[k]) - if not is_equal: - raise ValueError("variable %s not equal across datasets" % k) + raise ValueError("Dataset global attributes not equal.") # we've already verified everything is consistent; now, calculate # shared dimension sizes so we can expand the necessary variables - dim_lengths = [ds.dims.get(dim, 1) for ds in datasets] - non_concat_dims = {} - for ds in datasets: - non_concat_dims.update(ds.dims) - non_concat_dims.pop(dim, None) - def ensure_common_dims(vars): # ensure each variable with the given name shares the same # dimensions and the same shape for all of them except along the @@ -302,25 +352,27 @@ def ensure_common_dims(vars): common_dims = tuple(pd.unique([d for v in vars for d in v.dims])) if dim not in common_dims: common_dims = (dim,) + common_dims - for var, dim_len in zip(vars, dim_lengths): + for var, dim_len in zip(vars, concat_dim_lengths): if var.dims != common_dims: - common_shape = tuple( - non_concat_dims.get(d, dim_len) for d in common_dims - ) + common_shape = tuple(dims_sizes.get(d, dim_len) for d in common_dims) var = var.set_dims(common_dims, common_shape) yield var # stack up each variable to fill-out the dataset (in order) + # n.b. this loop preserves variable order, needed for groupby. for k in datasets[0].variables: if k in concat_over: vars = ensure_common_dims([ds.variables[k] for ds in datasets]) combined = concat_vars(vars, dim, positions) - insert_result_variable(k, combined) + assert isinstance(combined, Variable) + result_vars[k] = combined result = Dataset(result_vars, attrs=result_attrs) - result = result.set_coords(result_coord_names) + result = result.set_coords(coord_names) result.encoding = result_encoding + result = result.drop(unlabeled_dims, errors="ignore") + if coord is not None: # add concat dimension last to ensure that its in the final Dataset result[coord.name] = coord @@ -342,7 +394,7 @@ def _dataarray_concat( if data_vars != "all": raise ValueError( - "data_vars is not a valid argument when " "concatenating DataArray objects" + "data_vars is not a valid argument when concatenating DataArray objects" ) datasets = [] diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 7937a352cc6..d9e98839419 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -1551,8 +1551,8 @@ def set_index( obj : DataArray Another DataArray, with this data but replaced coordinates. - Example - ------- + Examples + -------- >>> arr = xr.DataArray(data=np.ones((2, 3)), ... dims=['x', 'y'], ... coords={'x': diff --git a/xarray/core/merge.py b/xarray/core/merge.py index 225507b9204..6dba659f992 100644 --- a/xarray/core/merge.py +++ b/xarray/core/merge.py @@ -44,6 +44,7 @@ "broadcast_equals": 2, "minimal": 3, "no_conflicts": 4, + "override": 5, } ) @@ -70,8 +71,8 @@ class MergeError(ValueError): # TODO: move this to an xarray.exceptions module? -def unique_variable(name, variables, compat="broadcast_equals"): - # type: (Any, List[Variable], str) -> Variable +def unique_variable(name, variables, compat="broadcast_equals", equals=None): + # type: (Any, List[Variable], str, bool) -> Variable """Return the unique variable from a list of variables or raise MergeError. Parameters @@ -81,8 +82,10 @@ def unique_variable(name, variables, compat="broadcast_equals"): variables : list of xarray.Variable List of Variable objects, all of which go by the same name in different inputs. - compat : {'identical', 'equals', 'broadcast_equals', 'no_conflicts'}, optional + compat : {'identical', 'equals', 'broadcast_equals', 'no_conflicts', 'override'}, optional Type of equality check to use. + equals: None or bool, + corresponding to result of compat test Returns ------- @@ -93,30 +96,38 @@ def unique_variable(name, variables, compat="broadcast_equals"): MergeError: if any of the variables are not equal. """ # noqa out = variables[0] - if len(variables) > 1: - combine_method = None - if compat == "minimal": - compat = "broadcast_equals" + if len(variables) == 1 or compat == "override": + return out + + combine_method = None + + if compat == "minimal": + compat = "broadcast_equals" + + if compat == "broadcast_equals": + dim_lengths = broadcast_dimension_size(variables) + out = out.set_dims(dim_lengths) + + if compat == "no_conflicts": + combine_method = "fillna" - if compat == "broadcast_equals": - dim_lengths = broadcast_dimension_size(variables) - out = out.set_dims(dim_lengths) + if equals is None: + out = out.compute() + for var in variables[1:]: + equals = getattr(out, compat)(var) + if not equals: + break - if compat == "no_conflicts": - combine_method = "fillna" + if not equals: + raise MergeError( + "conflicting values for variable %r on objects to be combined. You can skip this check by specifying compat='override'." + % (name) + ) + if combine_method: for var in variables[1:]: - if not getattr(out, compat)(var): - raise MergeError( - "conflicting values for variable %r on " - "objects to be combined:\n" - "first value: %r\nsecond value: %r" % (name, out, var) - ) - if combine_method: - # TODO: add preservation of attrs into fillna - out = getattr(out, combine_method)(var) - out.attrs = var.attrs + out = getattr(out, combine_method)(var) return out @@ -152,7 +163,7 @@ def merge_variables( priority_vars : mapping with Variable or None values, optional If provided, variables are always taken from this dict in preference to the input variable dictionaries, without checking for conflicts. - compat : {'identical', 'equals', 'broadcast_equals', 'minimal', 'no_conflicts'}, optional + compat : {'identical', 'equals', 'broadcast_equals', 'minimal', 'no_conflicts', 'override'}, optional Type of equality check to use when checking for conflicts. Returns @@ -449,7 +460,7 @@ def merge_core( ---------- objs : list of mappings All values must be convertable to labeled arrays. - compat : {'identical', 'equals', 'broadcast_equals', 'no_conflicts'}, optional + compat : {'identical', 'equals', 'broadcast_equals', 'no_conflicts', 'override'}, optional Compatibility checks to use when merging variables. join : {'outer', 'inner', 'left', 'right'}, optional How to combine objects with different indexes. @@ -519,7 +530,7 @@ def merge(objects, compat="no_conflicts", join="outer", fill_value=dtypes.NA): objects : Iterable[Union[xarray.Dataset, xarray.DataArray, dict]] Merge together all variables from these objects. If any of them are DataArray objects, they must have a name. - compat : {'identical', 'equals', 'broadcast_equals', 'no_conflicts'}, optional + compat : {'identical', 'equals', 'broadcast_equals', 'no_conflicts', 'override'}, optional String indicating how to compare variables of the same name for potential conflicts: @@ -531,6 +542,7 @@ def merge(objects, compat="no_conflicts", join="outer", fill_value=dtypes.NA): - 'no_conflicts': only values which are not null in both datasets must be equal. The returned dataset then contains the combination of all non-null values. + - 'override': skip comparing and pick variable from first dataset join : {'outer', 'inner', 'left', 'right', 'exact'}, optional String indicating how to combine differing indexes in objects. diff --git a/xarray/tests/test_combine.py b/xarray/tests/test_combine.py index f786a851e62..1abca30d199 100644 --- a/xarray/tests/test_combine.py +++ b/xarray/tests/test_combine.py @@ -327,13 +327,13 @@ class TestCheckShapeTileIDs: def test_check_depths(self): ds = create_test_data(0) combined_tile_ids = {(0,): ds, (0, 1): ds} - with raises_regex(ValueError, "sub-lists do not have " "consistent depths"): + with raises_regex(ValueError, "sub-lists do not have consistent depths"): _check_shape_tile_ids(combined_tile_ids) def test_check_lengths(self): ds = create_test_data(0) combined_tile_ids = {(0, 0): ds, (0, 1): ds, (0, 2): ds, (1, 0): ds, (1, 1): ds} - with raises_regex(ValueError, "sub-lists do not have " "consistent lengths"): + with raises_regex(ValueError, "sub-lists do not have consistent lengths"): _check_shape_tile_ids(combined_tile_ids) @@ -565,11 +565,6 @@ def test_combine_concat_over_redundant_nesting(self): expected = Dataset({"x": [0]}) assert_identical(expected, actual) - def test_combine_nested_but_need_auto_combine(self): - objs = [Dataset({"x": [0, 1]}), Dataset({"x": [2], "wall": [0]})] - with raises_regex(ValueError, "cannot be combined"): - combine_nested(objs, concat_dim="x") - @pytest.mark.parametrize("fill_value", [dtypes.NA, 2, 2.0]) def test_combine_nested_fill_value(self, fill_value): datasets = [ @@ -618,7 +613,7 @@ def test_combine_by_coords(self): assert_equal(actual, expected) objs = [Dataset({"x": 0}), Dataset({"x": 1})] - with raises_regex(ValueError, "Could not find any dimension " "coordinates"): + with raises_regex(ValueError, "Could not find any dimension coordinates"): combine_by_coords(objs) objs = [Dataset({"x": [0], "y": [0]}), Dataset({"x": [0]})] @@ -761,7 +756,7 @@ def test_auto_combine(self): auto_combine(objs) objs = [Dataset({"x": [0], "y": [0]}), Dataset({"x": [0]})] - with pytest.raises(KeyError): + with raises_regex(ValueError, "'y' is not present in all datasets"): auto_combine(objs) def test_auto_combine_previously_failed(self): diff --git a/xarray/tests/test_concat.py b/xarray/tests/test_concat.py index ee99ca027d9..00428f70966 100644 --- a/xarray/tests/test_concat.py +++ b/xarray/tests/test_concat.py @@ -5,8 +5,7 @@ import pytest from xarray import DataArray, Dataset, Variable, concat -from xarray.core import dtypes - +from xarray.core import dtypes, merge from . import ( InaccessibleArray, assert_array_equal, @@ -18,6 +17,34 @@ from .test_dataset import create_test_data +def test_concat_compat(): + ds1 = Dataset( + { + "has_x_y": (("y", "x"), [[1, 2]]), + "has_x": ("x", [1, 2]), + "no_x_y": ("z", [1, 2]), + }, + coords={"x": [0, 1], "y": [0], "z": [-1, -2]}, + ) + ds2 = Dataset( + { + "has_x_y": (("y", "x"), [[3, 4]]), + "has_x": ("x", [1, 2]), + "no_x_y": (("q", "z"), [[1, 2]]), + }, + coords={"x": [0, 1], "y": [1], "z": [-1, -2], "q": [0]}, + ) + + result = concat([ds1, ds2], dim="y", data_vars="minimal", compat="broadcast_equals") + assert_equal(ds2.no_x_y, result.no_x_y.transpose()) + + for var in ["has_x", "no_x_y"]: + assert "y" not in result[var] + + with raises_regex(ValueError, "'q' is not present in all datasets"): + concat([ds1, ds2], dim="q", data_vars="all", compat="broadcast_equals") + + class TestConcatDataset: @pytest.fixture def data(self): @@ -92,7 +119,7 @@ def test_concat_coords(self): actual = concat(objs, dim="x", coords=coords) assert_identical(expected, actual) for coords in ["minimal", []]: - with raises_regex(ValueError, "not equal across"): + with raises_regex(merge.MergeError, "conflicting values"): concat(objs, dim="x", coords=coords) def test_concat_constant_index(self): @@ -103,8 +130,10 @@ def test_concat_constant_index(self): for mode in ["different", "all", ["foo"]]: actual = concat([ds1, ds2], "y", data_vars=mode) assert_identical(expected, actual) - with raises_regex(ValueError, "not equal across datasets"): - concat([ds1, ds2], "y", data_vars="minimal") + with raises_regex(merge.MergeError, "conflicting values"): + # previously dim="y", and raised error which makes no sense. + # "foo" has dimension "y" so minimal should concatenate it? + concat([ds1, ds2], "new_dim", data_vars="minimal") def test_concat_size0(self): data = create_test_data() @@ -134,6 +163,14 @@ def test_concat_errors(self): data = create_test_data() split_data = [data.isel(dim1=slice(3)), data.isel(dim1=slice(3, None))] + with raises_regex(ValueError, "must supply at least one"): + concat([], "dim1") + + with raises_regex(ValueError, "Cannot specify both .*='different'"): + concat( + [data, data], dim="concat_dim", data_vars="different", compat="override" + ) + with raises_regex(ValueError, "must supply at least one"): concat([], "dim1") @@ -146,7 +183,7 @@ def test_concat_errors(self): concat([data0, data1], "dim1", compat="identical") assert_identical(data, concat([data0, data1], "dim1", compat="equals")) - with raises_regex(ValueError, "encountered unexpected"): + with raises_regex(ValueError, "present in some datasets"): data0, data1 = deepcopy(split_data) data1["foo"] = ("bar", np.random.randn(10)) concat([data0, data1], "dim1") diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index d105765481e..76b3ed1a8d6 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -825,7 +825,6 @@ def kernel(name): """Dask kernel to test pickling/unpickling and __repr__. Must be global to make it pickleable. """ - print("kernel(%s)" % name) global kernel_call_count kernel_call_count += 1 return np.ones(1, dtype=np.int64) diff --git a/xarray/tests/test_merge.py b/xarray/tests/test_merge.py index ed1453ce95d..c1e6c7a5ce8 100644 --- a/xarray/tests/test_merge.py +++ b/xarray/tests/test_merge.py @@ -196,6 +196,8 @@ def test_merge_compat(self): with raises_regex(ValueError, "compat=.* invalid"): ds1.merge(ds2, compat="foobar") + assert ds1.identical(ds1.merge(ds2, compat="override")) + def test_merge_auto_align(self): ds1 = xr.Dataset({"a": ("x", [1, 2]), "x": [0, 1]}) ds2 = xr.Dataset({"b": ("x", [3, 4]), "x": [1, 2]})