Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dataset/DataArray arithmethic #637

Draft
wants to merge 6 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions mikeio/dataset/_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,8 +254,6 @@ def _parse_geometry(
if dims == ("time", "x"):
return Grid1D(nx=shape[1], dx=1.0 / (shape[1] - 1))

warnings.warn("Geometry is required for ndim >=1")

axis = 1 if "time" in dims else 0
# dims_no_time = tuple([d for d in dims if d != "time"])
# shape_no_time = shape[1:] if ("time" in dims) else shape
Expand Down Expand Up @@ -292,6 +290,9 @@ def _parse_geometry(
assert shape[axis + 1] == geometry.nx, "data shape does not match nx"
# elif isinstance(geometry, Grid3D): # TODO

if geometry is None:
geometry = GeometryUndefined()

return geometry

@staticmethod
Expand Down Expand Up @@ -1669,8 +1670,16 @@ def _apply_math_operation(

# TODO: check if geometry etc match if other is DataArray?

new_da = self.copy() # TODO: alternatively: create new dataset (will validate)
new_da.values = data
# new_da = self.copy() # TODO: alternatively: create new dataset (will validate)
# new_da.values = data

time = self.time
if isinstance(other, DataArray):
time = other.time if len(self.time) == 1 else self.time

new_da = DataArray(
data=data, time=time, geometry=self.geometry, item=self.item, zn=self._zn
)

if not self._keep_EUM_after_math_operation(other, func):
other_name = other.name if hasattr(other, "name") else "array"
Expand Down
24 changes: 9 additions & 15 deletions mikeio/dataset/_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1696,17 +1696,12 @@ def __mul__(self, other: "Dataset" | float) -> "Dataset":

def _add_dataset(self, other: "Dataset", sign: float = 1.0) -> "Dataset":
self._check_datasets_match(other)
try:
data = [
self[x].to_numpy() + sign * other[y].to_numpy()
for x, y in zip(self.items, other.items)
]
except TypeError:
raise TypeError("Could not add data in Dataset")
newds = self.copy()
for j in range(len(self)):
newds[j].values = data[j] # type: ignore
return newds
das = []
for da1, da2 in zip(self, other):
da = da1 + sign * da2
das.append(da)

return Dataset(das, validate=False)

def _check_datasets_match(self, other: "Dataset") -> None:
if self.n_items != other.n_items:
Expand All @@ -1722,10 +1717,9 @@ def _check_datasets_match(self, other: "Dataset") -> None:
raise ValueError(
f"Item units must match. Item {j}: {self.items[j].unit} != {other.items[j].unit}"
)
if not np.all(self.time == other.time):
raise ValueError("All timesteps must match")
if self.shape != other.shape:
raise ValueError("shape must match")
if len(self.time) > 1 and len(other.time) > 1:
if not np.all(self.time == other.time):
raise ValueError("All timesteps must match")

def _add_value(self, value: float) -> "Dataset":
try:
Expand Down
102 changes: 48 additions & 54 deletions tests/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,31 +167,6 @@ def test_data_0d(da0):
assert "values" in repr(da0)


def test_create_data_1d_default_grid():
da = mikeio.DataArray(
data=np.zeros((10, 5)),
time=pd.date_range(start="2000-01-01", freq="h", periods=10),
item=ItemInfo("Foo"),
)
assert isinstance(da.geometry, mikeio.Grid1D)


# def test_data_2d_no_geometry_not_allowed():

# nt = 10
# nx = 7
# ny = 14

# with pytest.warns(Warning) as w:
# mikeio.DataArray(
# data=np.zeros([nt, ny, nx]) + 0.1,
# time=pd.date_range(start="2000-01-01", freq="S", periods=nt),
# item=ItemInfo("Foo"),
# )

# assert "geometry" in str(w[0].message).lower()


def test_dataarray_init():
nt = 10
start = 10.0
Expand Down Expand Up @@ -239,56 +214,57 @@ def test_dataarray_init_2d():

# 2d with time
ny, nx = 5, 6
geometry = mikeio.Grid2D(ny=ny, nx=nx, dx=1.0)
data2d = np.zeros([nt, ny, nx]) + 0.1
da = mikeio.DataArray(data=data2d, time=time)
da = mikeio.DataArray(data=data2d, time=time, geometry=geometry)
assert da.ndim == 3
assert da.dims == ("time", "y", "x")

# singleton time, requires spec of dims
dims = ("time", "y", "x")
data2d = np.zeros([1, ny, nx]) + 0.1
da = mikeio.DataArray(data=data2d, time="2018", dims=dims)
da = mikeio.DataArray(data=data2d, time="2018", dims=dims, geometry=geometry)
assert isinstance(da, mikeio.DataArray)
assert da.n_timesteps == 1
assert da.ndim == 3
assert da.dims == dims

# no time
data2d = np.zeros([ny, nx]) + 0.1
da = mikeio.DataArray(data=data2d, time="2018")
da = mikeio.DataArray(data=data2d, time="2018", geometry=geometry)
assert isinstance(da, mikeio.DataArray)
assert da.n_timesteps == 1
assert da.ndim == 2
assert da.dims == ("y", "x")

# x, y swapped
dims = ("x", "y")
data2d = np.zeros([nx, ny]) + 0.1
da = mikeio.DataArray(data=data2d, time="2018", dims=dims)
assert da.n_timesteps == 1
assert da.ndim == 2
assert da.dims == dims
# # x, y swapped
# dims = ("x", "y")
# data2d = np.zeros([nx, ny]) + 0.1
# da = mikeio.DataArray(data=data2d, time="2018", dims=dims)
# assert da.n_timesteps == 1
# assert da.ndim == 2
# assert da.dims == dims


def test_dataarray_init_5d():
nt = 10
time = pd.date_range(start="2000-01-01", freq="s", periods=nt)
# def test_dataarray_init_5d():
# nt = 10
# time = pd.date_range(start="2000-01-01", freq="S", periods=nt)

# 5d with named dimensions
dims = ("x", "y", "layer", "member", "season")
data5d = np.zeros([2, 4, 5, 3, 3]) + 0.1
da = mikeio.DataArray(data=data5d, time="2018", dims=dims)
assert da.n_timesteps == 1
assert da.ndim == 5
assert da.dims == dims
# # 5d with named dimensions
# dims = ("x", "y", "layer", "member", "season")
# data5d = np.zeros([2, 4, 5, 3, 3]) + 0.1
# da = mikeio.DataArray(data=data5d, time="2018", dims=dims)
# assert da.n_timesteps == 1
# assert da.ndim == 5
# assert da.dims == dims

# 5d with named dimensions and time
dims = ("time", "dummy", "layer", "member", "season")
data5d = np.zeros([nt, 4, 5, 3, 3]) + 0.1
da = mikeio.DataArray(data=data5d, time=time, dims=dims)
assert da.n_timesteps == nt
assert da.ndim == 5
assert da.dims == dims
# # 5d with named dimensions and time
# dims = ("time", "dummy", "layer", "member", "season")
# data5d = np.zeros([nt, 4, 5, 3, 3]) + 0.1
# da = mikeio.DataArray(data=data5d, time=time, dims=dims)
# assert da.n_timesteps == nt
# assert da.ndim == 5
# assert da.dims == dims


def test_dataarray_init_wrong_dim():
Expand Down Expand Up @@ -832,8 +808,9 @@ def test_modify_values_1d(da1):
assert da1.values[4] == 12.0

# values is scalar, therefore copy by definition. Original is not changed.
# TODO is the treatment of scalar sensible, i.e. consistent with xarray?
da1.isel(4).values = 11.0
da1.isel(4).values = (
11.0 # TODO is the treatment of scalar sensible, i.e. consistent with xarray?
)
assert da1.values[4] != 11.0

# fancy indexing will return copy! Original is *not* changed.
Expand Down Expand Up @@ -992,6 +969,23 @@ def test_multiply_two_dataarrays_broadcasting(da_grid2d):
assert da_grid2d.shape == da3.shape


def test_math_broadcasting(da1):
da2 = da1.mean("time")

da3 = da1 - da2
assert isinstance(da3, mikeio.DataArray)
assert da1.shape == da3.shape

da3 = da1 + da2
assert isinstance(da3, mikeio.DataArray)
assert da1.shape == da3.shape

# + is commutative
da4 = da2 + da1
assert isinstance(da4, mikeio.DataArray)
assert da1.shape == da4.shape


def test_math_two_dataarrays(da1):
da3 = da1 + da1
assert isinstance(da3, mikeio.DataArray)
Expand Down
Loading
Loading