Skip to content

Commit

Permalink
Add a test of the CLI for postprocessing
Browse files Browse the repository at this point in the history
  • Loading branch information
henryaddison committed Sep 11, 2023
1 parent ff85a87 commit 954dc96
Show file tree
Hide file tree
Showing 4 changed files with 156 additions and 62 deletions.
Empty file added tests/__init__.py
Empty file.
95 changes: 95 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
import cftime
import numpy as np
import pytest
import xarray as xr

# Coordinate variables shared by the dummy datasets that the fixtures below build.

# 13 evenly spaced grid latitudes covering [-3, 3].
grid_latitude = xr.Variable(
    dims=["grid_latitude"], data=np.linspace(-3, 3, 13), attrs={}
)

# 17 evenly spaced grid longitudes covering [-4, 4].
grid_longitude = xr.Variable(
    dims=["grid_longitude"], data=np.linspace(-4, 4, 17), attrs={}
)

# Ten daily timestamps on a 360-day calendar, at 12:00 from 1980-12-01 onwards.
time = xr.Variable(
    dims=["time"],
    data=xr.cftime_range(
        start=cftime.Datetime360Day(1980, 12, 1, 12, 0, 0, 0, has_year_zero=True),
        periods=10,
        freq="D",
    ),
)

# Daily edges at 00:00 from the same date: one more edge than there are timestamps.
time_bnds_values = xr.cftime_range(
    start=cftime.Datetime360Day(1980, 12, 1, 0, 0, 0, 0, has_year_zero=True),
    periods=len(time) + 1,
    freq="D",
).values

# Pair every edge with its successor, giving one (start, end) row per timestamp.
time_bnds_pairs = np.stack([time_bnds_values[:-1], time_bnds_values[1:]], axis=1)

time_bnds = xr.Variable(dims=["time", "bnds"], data=time_bnds_pairs, attrs={})


@pytest.fixture
def samples_set() -> xr.Dataset:
    """Build a dummy Dataset that looks like a set of samples from the emulator.

    The Dataset holds a randomly filled ``pred_pr`` variable for the single
    ensemble member ``"01"`` on the module-level time and grid coordinates,
    plus the module-level ``time_bnds`` variable.
    """
    members = xr.Variable(["ensemble_member"], np.array(["01"]))

    dims = ["ensemble_member", "time", "grid_latitude", "grid_longitude"]
    shape = (len(members), len(time), len(grid_latitude), len(grid_longitude))

    return xr.Dataset(
        data_vars={
            "pred_pr": xr.Variable(dims, np.random.rand(*shape)),
            "time_bnds": time_bnds,
        },
        coords={
            "ensemble_member": members,
            "time": time,
            "grid_latitude": grid_latitude,
            "grid_longitude": grid_longitude,
        },
    )


@pytest.fixture
def dataset() -> xr.Dataset:
    """Build a dummy Dataset representing a split of data for training and sampling.

    The Dataset holds randomly filled ``linpr`` and ``target_pr`` variables for
    three ensemble members (``"01"``, ``"02"``, ``"03"``) on the module-level
    time and grid coordinates, plus the module-level ``time_bnds`` variable.
    """
    members = xr.Variable(["ensemble_member"], np.array(["01", "02", "03"]))

    dims = ["ensemble_member", "time", "grid_latitude", "grid_longitude"]
    shape = (len(members), len(time), len(grid_latitude), len(grid_longitude))

    # Each variable gets its own independent random draw, linpr first.
    variables = {
        name: xr.Variable(dims, np.random.rand(*shape))
        for name in ("linpr", "target_pr")
    }
    variables["time_bnds"] = time_bnds

    return xr.Dataset(
        data_vars=variables,
        coords={
            "ensemble_member": members,
            "time": time,
            "grid_latitude": grid_latitude,
            "grid_longitude": grid_longitude,
        },
    )
61 changes: 61 additions & 0 deletions tests/ml_downscaling_emulator/bin/test_postprocess.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import pytest
import shortuuid
from typer.testing import CliRunner

from mlde_utils import samples_path

from ml_downscaling_emulator.bin import app

runner = CliRunner()


def test_filter(tmp_path, samples_file):
    """Check that the ``postprocess filter`` CLI command exits with status 0.

    The ``samples_file`` fixture is requested only for its side effect: it
    writes a predictions file for the same workdir, checkpoint, dataset and
    ensemble member that are passed to the command here.
    """
    time_period = "historic"
    workdir = tmp_path / "test-model"
    checkpoint = "epoch-1"
    ensemble_member = "01"
    dataset = "test-dataset"

    result = runner.invoke(
        app,
        [
            "postprocess",
            "filter",
            str(workdir),
            "--dataset",
            dataset,
            "--time-period",
            time_period,
            "--checkpoint",
            checkpoint,
            "--ensemble-member",
            ensemble_member,
        ],
    )

    # Attach the captured CLI output so a failing run explains itself.
    assert result.exit_code == 0, result.output


@pytest.fixture
def samples_file(tmp_path, samples_set):
    """Write ``samples_set`` to a NetCDF file and return the path of that file.

    The target directory is whatever ``samples_path`` returns for a workdir of
    ``tmp_path / "test-model"`` and the fixed checkpoint, transform, dataset,
    split and ensemble member below; it is created if it does not exist yet.
    The file name embeds a freshly generated shortuuid.
    """
    target_dir = samples_path(
        workdir=tmp_path / "test-model",
        checkpoint="epoch-1",
        input_xfm="stan",
        dataset="test-dataset",
        split="val",
        ensemble_member="01",
    )
    target_dir.mkdir(parents=True, exist_ok=True)

    target_file = target_dir / f"predictions-{shortuuid.uuid()}.nc"
    samples_set.to_netcdf(target_file)
    return target_file
62 changes: 0 additions & 62 deletions tests/ml_downscaling_emulator/deterministic/test_sampling.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
import cftime
import numpy as np
import pytest
import xarray as xr

from ml_downscaling_emulator.deterministic.sampling import sample_id
Expand All @@ -16,62 +13,3 @@ def test_sample_id(dataset: xr.Dataset):
assert (xr_samples["pred_pr"].values == em_dataset["linpr"].values).all()
for dim in ["time", "grid_latitude", "grid_longitude"]:
assert (xr_samples[dim].values == em_dataset[dim].values).all()


@pytest.fixture
def dataset() -> xr.Dataset:
    """Build a dummy Dataset that can be used for sampling.

    The Dataset holds randomly filled ``linpr`` and ``target_pr`` variables for
    three ensemble members on a small grid over ten daily timestamps of a
    360-day calendar, plus a ``time_bnds`` variable of (start, end) pairs.
    """
    grid_latitude = xr.Variable(
        dims=["grid_latitude"], data=np.linspace(-3, 3, 13), attrs={}
    )
    grid_longitude = xr.Variable(
        dims=["grid_longitude"], data=np.linspace(-4, 4, 17), attrs={}
    )

    # Daily timestamps at 12:00 starting on 1980-12-01.
    time = xr.Variable(
        dims=["time"],
        data=xr.cftime_range(
            start=cftime.Datetime360Day(1980, 12, 1, 12, 0, 0, 0, has_year_zero=True),
            periods=10,
            freq="D",
        ),
    )

    # Daily edges at 00:00: one more edge than there are timestamps.
    edges = xr.cftime_range(
        start=cftime.Datetime360Day(1980, 12, 1, 0, 0, 0, 0, has_year_zero=True),
        periods=len(time) + 1,
        freq="D",
    ).values
    # Pair every edge with its successor, giving one (start, end) row per timestamp.
    time_bnds = xr.Variable(
        dims=["time", "bnds"],
        data=np.stack([edges[:-1], edges[1:]], axis=1),
        attrs={},
    )

    members = xr.Variable(["ensemble_member"], np.array(["01", "02", "03"]))

    dims = ["ensemble_member", "time", "grid_latitude", "grid_longitude"]
    shape = (len(members), len(time), len(grid_latitude), len(grid_longitude))

    # Each variable gets its own independent random draw, linpr first.
    variables = {
        name: xr.Variable(dims, np.random.rand(*shape))
        for name in ("linpr", "target_pr")
    }
    variables["time_bnds"] = time_bnds

    return xr.Dataset(
        data_vars=variables,
        coords={
            "ensemble_member": members,
            "time": time,
            "grid_latitude": grid_latitude,
            "grid_longitude": grid_longitude,
        },
    )

0 comments on commit 954dc96

Please sign in to comment.