Skip to content

Commit

Permalink
create function max_point_displacement, _max_point_displacement_video…
Browse files Browse the repository at this point in the history
…. Add to yaml file. Create test for new function . . . will need to edit
  • Loading branch information
grquach committed Jul 10, 2024
1 parent 324377e commit 2789b61
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 1 deletion.
8 changes: 7 additions & 1 deletion sleap/config/suggestions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ main:
label: Method
type: stacked
default: " "
options: " ,image features,sample,prediction score,velocity,frame chunk"
options: " ,image features,sample,prediction score,velocity,frame chunk,max point displacement"
" ":

sample:
Expand Down Expand Up @@ -175,6 +175,12 @@ main:
type: double
default: 0.1
range: 0.1,1.0

"max point displacement":
- name: per_video
label: Threshold
type: int
default: 10

- name: target
label: Target
Expand Down
52 changes: 52 additions & 0 deletions sleap/gui/suggestions.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def suggest(cls, params: dict, labels: "Labels" = None) -> List[SuggestionFrame]
prediction_score=cls.prediction_score,
velocity=cls.velocity,
frame_chunk=cls.frame_chunk,
max_point_displacement = cls.max_point_displacement,
)

method = str.replace(params["method"], " ", "_")
Expand Down Expand Up @@ -291,6 +292,57 @@ def _velocity_video(

return cls.idx_list_to_frame_list(frame_idxs, video)

def max_point_displacement(
cls,
labels: "Labels",
videos: List[Video],
displacement_threshold: float,
**kwargs,
):
"""Finds frames with maximum point displacement above a threshold."""

proposed_suggestions = []
for video in videos:
proposed_suggestions.extend(
cls._max_point_displacement_video(video, labels, displacement_threshold)
)

suggestions = VideoFrameSuggestions.filter_unique_suggestions(
labels, videos, proposed_suggestions
)

return suggestions

@classmethod
def _max_point_displacement_video(
cls, video: Video, labels: "Labels", displacement_threshold: float
):
lfs = labels.find(video)
frames = len(lfs)

if frames < 2:
return []

displacements = []
for i in range(1, frames):
prev_lf = lfs[i - 1]
curr_lf = lfs[i]
prev_points = np.array([inst.points_array for inst in prev_lf.instances_to_show])
curr_points = np.array([inst.points_array for inst in curr_lf.instances_to_show])

if prev_points.shape != curr_points.shape:
continue

displacement = np.linalg.norm(curr_points - prev_points, axis=2).sum()
displacements.append((displacement, curr_lf.frame_idx))

frame_idxs = [
frame_idx for displacement, frame_idx in displacements if displacement > displacement_threshold
]

return cls.idx_list_to_frame_list(frame_idxs, video)


@classmethod
def frame_chunk(
cls,
Expand Down
13 changes: 13 additions & 0 deletions tests/gui/test_suggestions.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,19 @@ def test_velocity_suggestions(centered_pair_predictions):
assert suggestions[0].frame_idx == 21
assert suggestions[1].frame_idx == 45

# something like this?
def test_max_point_displacement_suggestions(centered_pair_predictions):
suggestions = VideoFrameSuggestions.suggest(
labels=centered_pair_predictions,
params=dict(
videos=centered_pair_predictions.videos,
method="max_point_displacement",
displacement_threshold = 3
),
)
assert len(suggestions) == 45
assert suggestions[0].frame_idx == 21
assert suggestions[1].frame_idx == 45

def test_frame_increment(centered_pair_predictions: Labels):
# Testing videos that have less frames than desired Samples per Video (stride)
Expand Down

0 comments on commit 2789b61

Please sign in to comment.