Skip to content

Commit

Permalink
Test case / code fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
emdavis02 committed Jun 12, 2024
1 parent 83f444a commit b0ac880
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 18 deletions.
50 changes: 33 additions & 17 deletions sleap/nn/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -5313,11 +5313,14 @@ def _make_provider_from_cli(args: argparse.Namespace) -> Tuple[Provider, str]:

# Provider list to accomodate multiple video inputs
provider_list = []
tmp_data_path_list = []
for data_path_file in data_path_list:
data_path_file = Path(data_path_file)
# Create a provider for each file
if data_path_file.endswith(".slp"):
labels = sleap.load_file(data_path_file)

if data_path_file.as_posix().endswith(".slp"):
print(f"Sleap file: {data_path_file}")
labels = sleap.load_file(data_path_file.as_posix())

if args.only_labeled_frames:
provider_list.append(LabelsReader.from_user_labeled_frames(labels))
elif args.only_suggested_frames:
Expand All @@ -5329,17 +5332,26 @@ def _make_provider_from_cli(args: argparse.Namespace) -> Tuple[Provider, str]:
))
else:
provider_list.append(LabelsReader(labels))

tmp_data_path_list.append(data_path_file)

else:
print(f"Video: {data_path_file}")
# TODO: Clean this up.
video_kwargs = dict(
dataset=vars(args).get("video.dataset"),
input_format=vars(args).get("video.input_format"),
)
provider_list.append(VideoReader.from_filepath(
filename=data_path_file, example_indices=frame_list(args.frames), **video_kwargs
))
try:
video_kwargs = dict(
dataset=vars(args).get("video.dataset"),
input_format=vars(args).get("video.input_format"),
)
provider_list.append(VideoReader.from_filepath(
filename=data_path_file.as_posix(), example_indices=frame_list(args.frames), **video_kwargs
))
print(f"Video: {data_path_file}")
tmp_data_path_list.append(data_path_file)
# TODO: Clean this up.
except Exception as e:
print(f"Error reading file: {data_path_file}")

data_path_list = tmp_data_path_list


return provider_list, data_path_list

Expand Down Expand Up @@ -5433,7 +5445,7 @@ def main(args: Optional[list] = None):
# Parse inputs.
args, _ = parser.parse_known_args(args)
print("Args:")
pprint(vars(args))
print(vars(args))
print()

# Setup devices.
Expand Down Expand Up @@ -5498,13 +5510,17 @@ def main(args: Optional[list] = None):
for data_path, provider in zip(data_path_list, provider_list):
# Setup models.
predictor = _make_predictor_from_cli(args)
print(f"predictor.tracker: {tracker}")
predictor.tracker = tracker

# Run inference!
labels_pr = predictor.predict(provider)

if output_path is None:
output_path = data_path + ".predictions.slp"
#if data_path.as_posix().endswith(".slp"):
# output_path = data_path
#else:
output_path = data_path.parent / (data_path.stem + ".predictions.slp")


labels_pr.provenance["model_paths"] = predictor.model_paths
Expand All @@ -5524,8 +5540,8 @@ def main(args: Optional[list] = None):
labels_pr.provenance["sleap_version"] = sleap.__version__
labels_pr.provenance["platform"] = platform.platform()
labels_pr.provenance["command"] = " ".join(sys.argv)
labels_pr.provenance["data_path"] = data_path
labels_pr.provenance["output_path"] = output_path
labels_pr.provenance["data_path"] = data_path.as_posix()
labels_pr.provenance["output_path"] = output_path.as_posix()
labels_pr.provenance["total_elapsed"] = total_elapsed
labels_pr.provenance["start_timestamp"] = start_timestamp
labels_pr.provenance["finish_timestamp"] = finish_timestamp
Expand All @@ -5552,7 +5568,7 @@ def main(args: Optional[list] = None):
for data_path, provider in zip(data_path_list, provider_list):
# Load predictions
print("Loading predictions...")
labels_pr = sleap.load_file(args.data_path)
labels_pr = sleap.load_file(args.data_path.as_posix())
frames = sorted(labels_pr.labeled_frames, key=lambda lf: lf.frame_idx)

print("Starting tracker...")
Expand Down
64 changes: 63 additions & 1 deletion tests/nn/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import zipfile
from pathlib import Path
from typing import cast
import os

import numpy as np
import pytest
Expand Down Expand Up @@ -63,6 +64,14 @@

# sleap.nn.system.use_cpu_only()

@pytest.fixture
def test_sleap_track_mult_inputs_folder_slp_files():
return "tests/data/videos/slp_multiple_inputs"

@pytest.fixture
def test_sleap_track_mult_inputs_folder():
return "tests/data/videos/multiple_inputs"


@pytest.fixture
def test_labels():
Expand Down Expand Up @@ -1447,7 +1456,7 @@ def test_make_predictor_from_cli(
assert predictor.max_instances == 5


def test_sleap_track(
def test_sleap_track_single_input(
centered_pair_predictions: Labels,
min_centroid_model_path: str,
min_centered_instance_model_path: str,
Expand All @@ -1473,6 +1482,59 @@ def test_sleap_track(
args = [slp_path, "--cpu"]
with pytest.raises(ValueError):
sleap_track(args=args)

#@pytest.mark.parametrize("tracking", ["simple", "flow", "simplemaxtracks", "flowmaxtracks", "None"])
def test_sleap_track_mult_input(
centered_pair_predictions: Labels,
min_centroid_model_path: str,
min_centered_instance_model_path: str,
#test_sleap_track_mult_inputs_folder: str,
test_sleap_track_mult_inputs_folder_slp_files: str,
#tracking
):
slp_path = test_sleap_track_mult_inputs_folder_slp_files
#slp_path = test_sleap_track_mult_inputs_folder
slp_path_obj = Path(slp_path)
Labels.save(centered_pair_predictions, slp_path)

# Create sleap-track command
args = (
f"{slp_path} --model {min_centroid_model_path} "
#f"--tracking.tracker {tracking} "
f"--tracking.tracker simple "
f"--model {min_centered_instance_model_path} --video.index 0 --frames 1-3 --cpu"
).split()

if Path.is_dir(slp_path_obj):
slp_path_list = []
for file_path in slp_path_obj.iterdir():
if file_path.is_file():
slp_path_list.append(file_path)
elif Path.is_file(slp_path_obj):
slp_path_list = [args.data_path]

for output_path in slp_path_list:
assert Path(output_path).exists()

# Run inference
sleap_track(args=args)
slp_path = Path(slp_path)

# Assert predictions file exists
if Path.is_dir(slp_path_obj):
new_slp_path_list = []
for file_path in slp_path_obj.iterdir():
if file_path.is_file():
new_slp_path_list.append(file_path)
elif Path.is_file(slp_path):
new_slp_path_list = [args.data_path]

for output_path in slp_path_list:
assert Path(output_path).exists()

files_to_remove = set(new_slp_path_list) - set(slp_path_list)
for file in files_to_remove:
file.unlink()


def test_flow_tracker(centered_pair_predictions: Labels, tmpdir):
Expand Down

0 comments on commit b0ac880

Please sign in to comment.