Skip to content

Commit

Permalink
Various small fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
shoyer committed Apr 7, 2019
1 parent eb50f50 commit 2ce0639
Show file tree
Hide file tree
Showing 7 changed files with 29 additions and 17 deletions.
2 changes: 1 addition & 1 deletion xarray/core/coordinates.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def _update_coords(self, coords):

self._data._variables = variables
self._data._coord_names.update(new_coord_names)
self._data._dims = dict(dims)
self._data._dims = dims
self._data._indexes = None

def __delitem__(self, key):
Expand Down
2 changes: 1 addition & 1 deletion xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def _infer_coords_and_dims(shape, coords, dims):
for dim, coord in zip(dims, coords):
var = as_variable(coord, name=dim)
var.dims = (dim,)
new_coords[dim] = var
new_coords[dim] = var.to_index_variable()

sizes = dict(zip(dims, shape))
for k, v in new_coords.items():
Expand Down
30 changes: 18 additions & 12 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def calculate_dimensions(variables):
Returns dictionary mapping from dimension names to sizes. Raises ValueError
if any of the dimension sizes conflict.
"""
dims = OrderedDict()
dims = {}
last_used = {}
scalar_vars = set(k for k, v in variables.items() if not v.dims)
for k, var in variables.items():
Expand Down Expand Up @@ -693,7 +693,7 @@ def _construct_direct(cls, variables, coord_names, dims, attrs=None,

@classmethod
def _from_vars_and_coord_names(cls, variables, coord_names, attrs=None):
dims = dict(calculate_dimensions(variables))
dims = calculate_dimensions(variables)
return cls._construct_direct(variables, coord_names, dims, attrs)

# TODO(shoyer): renable type checking on this signature when pytype has a
Expand Down Expand Up @@ -754,18 +754,20 @@ def _replace_with_new_dims( # type: ignore
coord_names: set = None,
attrs: 'Optional[OrderedDict]' = __default,
indexes: 'Optional[OrderedDict[Any, pd.Index]]' = __default,
encoding: Optional[dict] = __default,
inplace: bool = False,
) -> T:
"""Replace variables with recalculated dimensions."""
dims = dict(calculate_dimensions(variables))
dims = calculate_dimensions(variables)
return self._replace(
variables, coord_names, dims, attrs, indexes, inplace=inplace)
variables, coord_names, dims, attrs, indexes, encoding,
inplace=inplace)

def _replace_vars_and_dims( # type: ignore
self: T,
variables: 'OrderedDict[Any, Variable]' = None,
coord_names: set = None,
dims: 'OrderedDict[Any, int]' = None,
dims: Dict[Any, int] = None,
attrs: 'Optional[OrderedDict]' = __default,
inplace: bool = False,
) -> T:
Expand Down Expand Up @@ -1081,6 +1083,7 @@ def __delitem__(self, key):
"""
del self._variables[key]
self._coord_names.discard(key)
self._dims = calculate_dimensions(self._variables)

# mutable objects should not be hashable
# https://github.com/python/mypy/issues/4266
Expand Down Expand Up @@ -2463,7 +2466,7 @@ def expand_dims(self, dim=None, axis=None, **dim_kwargs):
else:
# If dims includes a label of a non-dimension coordinate,
# it will be promoted to a 1D coordinate with a single value.
variables[k] = v.set_dims(k)
variables[k] = v.set_dims(k).to_index_variable()

new_dims = self._dims.copy()
new_dims.update(dim)
Expand Down Expand Up @@ -3548,12 +3551,15 @@ def from_dict(cls, d):
def _unary_op(f, keep_attrs=False):
@functools.wraps(f)
def func(self, *args, **kwargs):
ds = self.coords.to_dataset()
for k in self.data_vars:
ds._variables[k] = f(self._variables[k], *args, **kwargs)
if keep_attrs:
ds._attrs = self._attrs
return ds
variables = OrderedDict()
for k, v in self._variables.items():
if k in self._coord_names:
variables[k] = v
else:
variables[k] = f(v, *args, **kwargs)
attrs = self._attrs if keep_attrs else None
return self._replace_with_new_dims(
variables, attrs=attrs, encoding=None)

return func

Expand Down
2 changes: 1 addition & 1 deletion xarray/core/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,7 +467,7 @@ def merge_core(objs,
'coordinates or not in the merged result: %s'
% ambiguous_coords)

return variables, coord_names, dict(dims)
return variables, coord_names, dims


def merge(objects, compat='no_conflicts', join='outer'):
Expand Down
4 changes: 3 additions & 1 deletion xarray/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,9 @@ def _assert_dataset_invariants(ds: Dataset):

assert type(ds._dims) is dict, ds._dims
assert all(isinstance(v, int) for v in ds._dims.values()), ds._dims
var_dims = set.union(*[set(v.dims) for v in ds._variables.values()])
var_dims = set() # type: set
for v in ds._variables.values():
var_dims.update(v.dims)
assert ds._dims.keys() == var_dims, (set(ds._dims), var_dims)
assert all(ds._dims[k] == v.sizes[k]
for v in ds._variables.values()
Expand Down
1 change: 0 additions & 1 deletion xarray/tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,6 @@ def source_ndarray(array):

# Internal versions of xarray's test functions that validate additional
# invariants
# TODO: add more invariant checks.

def assert_equal(a, b, *, check_invariants=True):
xarray.testing.assert_equal(a, b)
Expand Down
5 changes: 5 additions & 0 deletions xarray/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2677,6 +2677,11 @@ def test_delitem(self):
assert set(data.variables) == all_items - set(['var1', 'numbers'])
assert 'numbers' not in data.coords

expected = Dataset()
actual = Dataset({'y': ('x', [1, 2])})
del actual['y']
assert_identical(expected, actual)

def test_squeeze(self):
data = Dataset({'foo': (['x', 'y', 'z'], [[[1], [2]]])})
for args in [[], [['x']], [['x', 'z']]]:
Expand Down

0 comments on commit 2ce0639

Please sign in to comment.