Skip to content

Commit

Permalink
Exact test case
Browse files Browse the repository at this point in the history
  • Loading branch information
Wh1isper committed Dec 18, 2023
1 parent 6fcfc87 commit 9c22b0d
Showing 1 changed file with 23 additions and 10 deletions.
33 changes: 23 additions & 10 deletions tests/optmize/test_ndarry_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,38 +5,46 @@


@pytest.fixture
def ndarray_loader(tmp_path):
def ndarray_loader(tmp_path, ndarray_list):
cache_dir = tmp_path / "ndarrycache"
loader = NDArrayLoader(cache_dir=cache_dir)
for ndarray in ndarray_list:
loader.store(ndarray)
yield loader
loader.cleanup()


def test_ndarray_loader(ndarray_loader: NDArrayLoader):
ndarray_list = [
np.array([[1], [2], [3]]),
np.array([[4], [5], [6]]),
np.array([[7], [8], [9]]),
]
@pytest.fixture
def ndarray_list():
"""
1, 4, 7
2, 5, 8
3, 6, 9
"""
yield [
np.array([[1], [2], [3]]),
np.array([[4], [5], [6]]),
np.array([[7], [8], [9]]),
]

ndarray_all = np.concatenate(ndarray_list, axis=1)

for ndarray in ndarray_list:
ndarray_loader.store(ndarray)
def test_ndarray_loader_function(ndarray_loader: NDArrayLoader, ndarray_list):
ndarray_all = np.concatenate(ndarray_list, axis=1)

for i, ndarray in enumerate(ndarray_loader.iter()):
np.testing.assert_equal(ndarray, ndarray_list[i])
np.testing.assert_equal(ndarray_loader.get_all(), ndarray_all)

assert ndarray_loader.shape == ndarray_all.shape


def test_ndarray_loader_slice(ndarray_loader: NDArrayLoader, ndarray_list):
ndarray_all = np.concatenate(ndarray_list, axis=1)

np.testing.assert_equal(ndarray_loader[:], ndarray_all[:])
np.testing.assert_equal(ndarray_loader[::], ndarray_all[::])
np.testing.assert_equal(ndarray_loader[:, :], ndarray_all[:, :])
np.testing.assert_equal(ndarray_loader[::, ::], ndarray_all[::, ::])
np.testing.assert_equal(ndarray_loader[:, 1], ndarray_all[:, 1])
np.testing.assert_equal(ndarray_loader[1, :], ndarray_all[1, :])

Expand All @@ -52,7 +60,12 @@ def test_ndarray_loader(ndarray_loader: NDArrayLoader):
5
6
"""

np.testing.assert_equal(ndarray_loader[1:3, 1:3], ndarray_all[1:3, 1:3])
"""
5, 6
8, 9
"""


if __name__ == "__main__":
Expand Down

0 comments on commit 9c22b0d

Please sign in to comment.