Skip to content

Commit

Permalink
remove unnecessary for loop, calculate proper displacement, adjusted …
Browse files Browse the repository at this point in the history
…tests accordingly
  • Loading branch information
grquach committed Jul 17, 2024
1 parent bee834d commit 8cc046c
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 54 deletions.
6 changes: 3 additions & 3 deletions sleap/config/suggestions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -175,10 +175,10 @@ main:
type: double
default: 0.1
range: 0.1,1.0

"max point displacement":
- name: per_video
label: Threshold
- name: displacement_threshold
label: Maximum Displacement Threshold
type: int
default: 10

Expand Down
63 changes: 16 additions & 47 deletions sleap/gui/suggestions.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,56 +319,25 @@ def max_point_displacement(
def _max_point_displacement_video(
cls, video: Video, labels: "Labels", displacement_threshold: float
):
# ONCE labels.numpy works: delete lfs ~322 - 328
lfs = labels.find(video)
frames = len(lfs)

if frames < 2:
return []
# Get numpy of shape (frames, tracks, nodes, x, y)
labels_numpy = labels.numpy(video=video, all_frames=True, untracked=False)


video_instances = labels.numpy(video=video, all_frames=True, untracked=False)
frames = len(video_instances)

if frames < 2:
# Return empty list if not enough frames
n_frames, n_tracks, n_nodes, _ = labels_numpy.shape
if n_frames < 2:
return []

# Calculate displacements
diff = labels_numpy[1:] - labels_numpy[:-1] # (frames - 1, tracks, nodes, x, y)
euc_norm = np.linalg.norm(diff, axis=-1) # (frames - 1, tracks, nodes)
mean_euc_norm = np.nanmean(euc_norm, axis=-1) # (frames - 1, tracks)

# Find frames where mean displacement is above threshold
threshold_mask = np.any(
mean_euc_norm > displacement_threshold, axis=-1
) # (frames - 1,)
frame_idxs = list(np.argwhere(threshold_mask).flatten()) # [0, len(frames - 1)]

# ONCE labels.numpy works: delete print statements ~336 - 340
print('type of video_instances: ', type(video_instances))
print(video_instances[0])
print('type of video_instances[0]: ', type(video_instances[0]))
print(f"Number of elements returned by labels.numpy(): {video_instances.shape}")
print(f"Number of elements returned by labels.numpy(): {len(video_instances)}")


print('type of video_instances: ', type(video_instances))
print('type of video_instances[0]: ', type(video_instances[0]))


displacements = []
for idx in range(1, frames):
prev_points = video_instances[idx-1]
curr_points = video_instances[idx]


if prev_points.shape != curr_points.shape:
continue

# Mask to identify non-nan values
valid_mask = ~np.isnan(prev_points) & ~np.isnan(curr_points)
# Filter out nan values
valid_prev_points = prev_points[valid_mask].reshape(-1, 2)
valid_curr_points = curr_points[valid_mask].reshape(-1, 2)

if valid_prev_points.size == 0 or valid_curr_points.size == 0:
continue

displacement = np.linalg.norm(valid_curr_points - valid_prev_points, axis=1).sum()
displacements.append((displacement, idx))

frame_idxs = [
frame_idx for displacement, frame_idx in displacements if displacement > displacement_threshold
]

return cls.idx_list_to_frame_list(frame_idxs, video)

Expand Down
11 changes: 7 additions & 4 deletions tests/gui/test_suggestions.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,12 @@ def test_max_point_displacement_suggestions(centered_pair_predictions):
params=dict(
videos=centered_pair_predictions.videos,
method="max_point_displacement",
displacement_threshold = 300
displacement_threshold = 6
),
)
assert len(suggestions) == 6
assert suggestions[0].frame_idx == 2117
assert suggestions[1].frame_idx == 4937
assert len(suggestions) == 19
assert suggestions[0].frame_idx == 27
assert suggestions[1].frame_idx == 81

def test_frame_increment(centered_pair_predictions: Labels):
# Testing videos that have less frames than desired Samples per Video (stride)
Expand Down Expand Up @@ -521,3 +521,6 @@ def check_all_predicted_instances(sugg, labels):
},
)
assert_suggestions_unique(labels, suggestions)

if __name__=="__main__":
pytest.main([f"{__file__}::test_max_point_displacement_suggestions"])

0 comments on commit 8cc046c

Please sign in to comment.