Skip to content

Commit

Permalink
Added tests for get_network_distribution
Browse files Browse the repository at this point in the history
  • Loading branch information
eberrigan committed Sep 22, 2023
1 parent 611d0e0 commit bfdb265
Showing 1 changed file with 98 additions and 0 deletions.
98 changes: 98 additions & 0 deletions tests/test_networklength.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pytest
import numpy as np
from shapely import LineString, Polygon
from sleap_roots import Series
from sleap_roots.convhull import get_chull_area, get_convhull
from sleap_roots.lengths import get_max_length_pts, get_root_lengths
Expand Down Expand Up @@ -169,6 +170,103 @@ def test_get_network_solidity_rice(rice_h5):
np.testing.assert_almost_equal(ratio, 0.03366254601775008, decimal=7)


def test_get_network_distribution_one_point():
# Define inputs
primary_pts = np.array([[[1, 1], [2, 2], [3, 3]]])
lateral_pts = np.array(
[[[4, 4], [5, 5]], [[6, 6], [np.nan, np.nan]]]
) # One of the roots has only one point
bounding_box = (0, 0, 10, 10)
fraction = 2 / 3
monocots = False

# Call the function
network_length = get_network_distribution(
primary_pts, lateral_pts, bounding_box, fraction, monocots
)

# Define the expected result
# Only the valid roots should be considered in the calculation
lower_box = Polygon(
[(0, 10 - 10 * (2 / 3)), (0, 10), (10, 10), (10, 10 - 10 * (2 / 3))]
)
expected_length = (
LineString(primary_pts[0]).intersection(lower_box).length
+ LineString(lateral_pts[0]).intersection(lower_box).length
)

# Assert that the result is as expected
assert network_length == pytest.approx(expected_length)


def test_get_network_distribution_empty_arrays():
primary_pts = np.full((2, 2), np.nan)
lateral_pts = np.full((2, 2, 2), np.nan)
bounding_box = (0, 0, 10, 10)

network_length = get_network_distribution(primary_pts, lateral_pts, bounding_box)
assert network_length == 0


def test_get_network_distribution_with_nans():
primary_pts = np.array([[1, 1], [2, 2], [np.nan, np.nan]])
lateral_pts = np.array([[[4, 4], [5, 5], [np.nan, np.nan]]])
bounding_box = (0, 0, 10, 10)

network_length = get_network_distribution(primary_pts, lateral_pts, bounding_box)

lower_box = Polygon(
[(0, 10 - 10 * (2 / 3)), (0, 10), (10, 10), (10, 10 - 10 * (2 / 3))]
)
expected_length = (
LineString(primary_pts[:-1]).intersection(lower_box).length
+ LineString(lateral_pts[0, :-1]).intersection(lower_box).length
)

assert network_length == pytest.approx(expected_length)


def test_get_network_distribution_monocots():
primary_pts = np.array([[1, 1], [2, 2], [3, 3]])
lateral_pts = np.array([[[4, 4], [5, 5]]])
bounding_box = (0, 0, 10, 10)
monocots = True

network_length = get_network_distribution(
primary_pts, lateral_pts, bounding_box, monocots=monocots
)

lower_box = Polygon(
[(0, 10 - 10 * (2 / 3)), (0, 10), (10, 10), (10, 10 - 10 * (2 / 3))]
)
expected_length = (
LineString(lateral_pts[0]).intersection(lower_box).length
) # Only lateral_pts are considered

assert network_length == pytest.approx(expected_length)


def test_get_network_distribution_different_fraction():
primary_pts = np.array([[1, 1], [2, 2], [3, 3]])
lateral_pts = np.array([[[4, 4], [5, 5]]])
bounding_box = (0, 0, 10, 10)
fraction = 0.5

network_length = get_network_distribution(
primary_pts, lateral_pts, bounding_box, fraction=fraction
)

lower_box = Polygon(
[(0, 10 - 10 * fraction), (0, 10), (10, 10), (10, 10 - 10 * fraction)]
)
expected_length = (
LineString(primary_pts).intersection(lower_box).length
+ LineString(lateral_pts[0]).intersection(lower_box).length
)

assert network_length == pytest.approx(expected_length)


def test_get_network_distribution(canola_h5):
series = Series.load(
canola_h5, primary_name="primary_multi_day", lateral_name="lateral_3_nodes"
Expand Down

0 comments on commit bfdb265

Please sign in to comment.