Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds Dataset.query() method, analogous to pandas DataFrame.query() #4984

Merged
merged 16 commits into from
Mar 16, 2021
Merged
1 change: 1 addition & 0 deletions ci/requirements/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ dependencies:
- nc-time-axis
- netcdf4
- numba
- numexpr
dcherian marked this conversation as resolved.
Show resolved Hide resolved
- numpy
- pandas
- pint
Expand Down
73 changes: 73 additions & 0 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -6980,5 +6980,78 @@ def argmax(self, dim=None, axis=None, **kwargs):
"Dataset.argmin() with a sequence or ... for dim"
)

def query(
self,
queries: Mapping[Hashable, Any] = None,
parser: str = "pandas",
engine: str = None,
missing_dims: str = "raise",
**queries_kwargs: Any,
) -> "Dataset":
"""Return a new dataset with each array indexed along the specified
dimension(s), where the indexers are given as strings containing
Python expressions to be evaluated against the data variables in the
dataset.

Parameters
----------
queries : dict, optional
A dic with keys matching dimensions and values given by strings
containing Python expressions to be evaluated against the data variables
in the dataset. The expressions will be evaluated using the pandas
eval() function, and can contain any valid Python expressions but cannot
contain any Python statements.
parser : {"pandas", "python"}, default: "pandas"
The parser to use to construct the syntax tree from the expression.
The default of 'pandas' parses code slightly different than standard
Python. Alternatively, you can parse an expression using the 'python'
parser to retain strict Python semantics.
engine: {"python", "numexpr", None}, default: None
The engine used to evaluate the expression. Supported engines are:
- None: tries to use numexpr, falls back to python
- "numexpr": evaluates expressions using numexpr
- "python": performs operations as if you had eval’d in top level python
missing_dims : {"raise", "warn", "ignore"}, default: "raise"
What to do if dimensions that should be selected from are not present in the
Dataset:
- "raise": raise an exception
- "warning": raise a warning, and ignore the missing dimensions
- "ignore": ignore the missing dimensions
**queries_kwargs : {dim: query, ...}, optional
The keyword arguments form of ``queries``.
One of queries or queries_kwargs must be provided.

Returns
-------
obj : Dataset
A new Dataset with the same contents as this dataset, except each
array and dimension is indexed by the results of the appropriate
queries.

See Also
--------
Dataset.isel
pandas.eval

"""

# allow queries to be given either as a dict or as kwargs
queries = either_dict_or_kwargs(queries, queries_kwargs, "query")
alimanfoo marked this conversation as resolved.
Show resolved Hide resolved

# check queries
for dim, expr in queries.items():
if not isinstance(expr, str):
msg = f"expr for dim {dim} must be a string to be evaluated, {type(expr)} given"
raise ValueError(msg)

# evaluate the queries to create the indexers
indexers = {
dim: pd.eval(expr, resolvers=[self], parser=parser, engine=engine)
for dim, expr in queries.items()
}

# apply the selection
return self.isel(indexers, missing_dims=missing_dims)


ops.inject_all_ops_and_reduce_methods(Dataset, array_only=False)
77 changes: 77 additions & 0 deletions xarray/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5807,6 +5807,83 @@ def test_astype_attrs(self):
assert not data.astype(float, keep_attrs=False).attrs
assert not data.astype(float, keep_attrs=False).var1.attrs

@pytest.mark.parametrize("parser", ["pandas", "python"])
@pytest.mark.parametrize("engine", ["python", "numexpr", None])
@pytest.mark.parametrize("backend", ["numpy", "dask"])
def test_query(self, backend, engine, parser):
"""Test querying a dataset."""

# setup test data
np.random.seed(42)
a = np.arange(0, 10, 1)
b = np.random.randint(0, 100, size=10)
c = np.linspace(0, 1, 20)
d = np.arange(0, 200).reshape(10, 20)
if backend == "numpy":
ds = Dataset(
{"a": ("x", a), "b": ("x", b), "c": ("y", c), "d": (("x", "y"), d)}
)
elif backend == "dask":
ds = Dataset(
{
"a": ("x", da.from_array(a, chunks=3)),
"b": ("x", da.from_array(b, chunks=3)),
"c": ("y", da.from_array(c, chunks=7)),
"d": (("x", "y"), da.from_array(d, chunks=(3, 7))),
}
)

# query single dim, single variable
actual = ds.query(x="a > 5", engine=engine, parser=parser)
expect = ds.isel(x=(a > 5))
assert_identical(expect, actual)

# query single dim, single variable, via dict
actual = ds.query(dict(x="a > 5"), engine=engine, parser=parser)
expect = ds.isel(dict(x=(a > 5)))
assert_identical(expect, actual)

# query single dim, single variable
actual = ds.query(x="b > 50", engine=engine, parser=parser)
expect = ds.isel(x=(b > 50))
assert_identical(expect, actual)

# query single dim, single variable
actual = ds.query(y="c < .5", engine=engine, parser=parser)
expect = ds.isel(y=(c < 0.5))
assert_identical(expect, actual)

# query single dim, multiple variables
actual = ds.query(x="(a > 5) & (b > 50)", engine=engine, parser=parser)
expect = ds.isel(x=((a > 5) & (b > 50)))
assert_identical(expect, actual)

# support pandas query parser
if parser == "pandas":
actual = ds.query(x="(a > 5) and (b > 50)", engine=engine, parser=parser)
expect = ds.isel(x=((a > 5) & (b > 50)))
assert_identical(expect, actual)

# query multiple dims via kwargs
actual = ds.query(x="a > 5", y="c < .5", engine=engine, parser=parser)
expect = ds.isel(x=(a > 5), y=(c < 0.5))
assert_identical(expect, actual)

# query multiple dims via dict
actual = ds.query(dict(x="a > 5", y="c < .5"), engine=engine, parser=parser)
expect = ds.isel(dict(x=(a > 5), y=(c < 0.5)))
assert_identical(expect, actual)

# test error handling
with pytest.raises(ValueError):
ds.query("a > 5") # must be dict
with pytest.raises(IndexError):
ds.query(y="a > 5") # wrong length dimension
with pytest.raises(IndexError):
ds.query(x="c < .5") # wrong length dimension
with pytest.raises(IndexError):
ds.query(x="d > 100") # wrong number of dimensions


# Py.test tests

Expand Down