Skip to content

Commit

Permalink
test case for output directory
Browse files Browse the repository at this point in the history
  • Loading branch information
emdavis02 committed Jun 18, 2024
1 parent dcc7a63 commit 35db452
Show file tree
Hide file tree
Showing 2 changed files with 121 additions and 71 deletions.
12 changes: 9 additions & 3 deletions sleap/nn/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -5298,12 +5298,12 @@ def _make_provider_from_cli(args: argparse.Namespace) -> Tuple[Provider, str]:

# Check for multiple video inputs
# Compile file(s) into a list for later itteration
if Path.is_dir(data_path):
if data_path.is_dir:
data_path_list = []
for file_path in data_path.iterdir():
if file_path.is_file():
data_path_list.append(file_path)
elif Path.is_file(data_path):
elif data_path.is_file:
data_path_list = [args.data_path]
else:
raise ValueError(
Expand Down Expand Up @@ -5365,6 +5365,7 @@ def _make_predictor_from_cli(args: argparse.Namespace) -> Predictor:
Returns:
The `Predictor` created from loaded models.
"""
print(args)
peak_threshold = None
for deprecated_arg in [
"single.peak_threshold",
Expand Down Expand Up @@ -5519,6 +5520,10 @@ def main(args: Optional[list] = None):
if output_path is None:

output_path = data_path.parent / (data_path.stem + ".predictions.slp")
output_path_obj = Path(output_path)

else:
output_path = output_path + "/" + (data_path.stem + ".predictions.slp")


labels_pr.provenance["model_paths"] = predictor.model_paths
Expand All @@ -5533,13 +5538,14 @@ def main(args: Optional[list] = None):
print("Finished inference at:", finish_timestamp)
print(f"Total runtime: {total_elapsed} secs")
print(f"Predicted frames: {len(labels_pr)}/{len(provider)}")


# Add provenance metadata to predictions.
labels_pr.provenance["sleap_version"] = sleap.__version__
labels_pr.provenance["platform"] = platform.platform()
labels_pr.provenance["command"] = " ".join(sys.argv)
labels_pr.provenance["data_path"] = data_path.as_posix()
labels_pr.provenance["output_path"] = output_path.as_posix()
labels_pr.provenance["output_path"] = output_path_obj.as_posix()
labels_pr.provenance["total_elapsed"] = total_elapsed
labels_pr.provenance["start_timestamp"] = start_timestamp
labels_pr.provenance["finish_timestamp"] = finish_timestamp
Expand Down
180 changes: 112 additions & 68 deletions tests/nn/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,10 @@
def test_sleap_track_mult_inputs_folder_slp():
return "tests/data/videos/multiple_inputs_slp"

@pytest.fixture
def test_sleap_track_output_folder():
return "tests/data/output_folder"

@pytest.fixture
def test_sleap_track_mult_inputs_folder_slp_mp4():
return "tests/data/videos/multiple_inputs_slp_mp4"
Expand Down Expand Up @@ -1458,6 +1462,37 @@ def test_make_predictor_from_cli(
elif isinstance(predictor, BottomUpPredictor):
assert predictor.max_instances == 5

def test_make_predictor_from_cli_mult_input(
centered_pair_predictions: Labels,
min_centroid_model_path: str,
min_centered_instance_model_path: str,
min_bottomup_model_path: str,
test_sleap_track_mult_inputs_folder_slp: str,
):
slp_path = str(Path(test_sleap_track_mult_inputs_folder_slp))
Labels.save(centered_pair_predictions, slp_path)

# Create sleap-track command
model_args = [
f"--model {min_centroid_model_path} --model {min_centered_instance_model_path}",
f"--model {min_bottomup_model_path}",
]
for model_arg in model_args:
#print(model_arg)
args = (
f"{slp_path} {model_arg} --video.index 0 --frames 1-3 "
"--cpu --max_instances 5"
).split()
parser = _make_cli_parser()
args, _ = parser.parse_known_args(args=args)

# Create predictor
predictor = _make_predictor_from_cli(args=args)
if isinstance(predictor, TopDownPredictor):
assert predictor.inference_model.centroid_crop.max_instances == 5
elif isinstance(predictor, BottomUpPredictor):
assert predictor.max_instances == 5


def test_sleap_track_single_input(
centered_pair_predictions: Labels,
Expand Down Expand Up @@ -1488,7 +1523,6 @@ def test_sleap_track_single_input(

@pytest.mark.parametrize("tracking", ["simple", "flow", "None"])
def test_sleap_track_mult_input_slp(
centered_pair_predictions: Labels,
min_centroid_model_path: str,
min_centered_instance_model_path: str,
test_sleap_track_mult_inputs_folder_slp: str,
Expand All @@ -1504,41 +1538,29 @@ def test_sleap_track_mult_input_slp(
f"--model {min_centered_instance_model_path} --video.index 0 --frames 1-3 --cpu"
).split()

if Path.is_dir(slp_path_obj):
slp_path_list = []
for file_path in slp_path_obj.iterdir():
if file_path.is_file():
slp_path_list.append(file_path)
elif Path.is_file(slp_path_obj):
slp_path_list = [args.data_path]

for output_path in slp_path_list:
assert Path(output_path).exists()
slp_path_list = [file for file in slp_path_obj.iterdir() if file.is_file()]

# Run inference
sleap_track(args=args)
slp_path = Path(slp_path)

# Assert predictions file exists
if Path.is_dir(slp_path_obj):
new_slp_path_list = []
for file_path in slp_path_obj.iterdir():
if file_path.is_file():
new_slp_path_list.append(file_path)
elif Path.is_file(slp_path):
new_slp_path_list = [args.data_path]

for output_path in slp_path_list:
assert Path(output_path).exists()
expected_extensions = {'.slp', '.mp4', '.avi'} # Add other video formats if necessary

for file_path in slp_path_list:
if file_path.suffix in expected_extensions:
expected_output_file = file_path.parent / (file_path.stem + ".predictions.slp")
print(f"PATH: {expected_output_file}")
assert Path(expected_output_file).exists()

new_slp_path_list = [file for file in slp_path_obj.iterdir() if file.is_file()]

files_to_remove = set(new_slp_path_list) - set(slp_path_list)
for file in files_to_remove:
file.unlink()


@pytest.mark.parametrize("tracking", ["simple", "flow", "None"])
def test_sleap_track_mult_input_slp_mp4(
centered_pair_predictions: Labels,
min_centroid_model_path: str,
min_centered_instance_model_path: str,
test_sleap_track_mult_inputs_folder_slp_mp4: str,
Expand All @@ -1554,86 +1576,108 @@ def test_sleap_track_mult_input_slp_mp4(
f"--model {min_centered_instance_model_path} --video.index 0 --frames 1-3 --cpu"
).split()

if Path.is_dir(slp_path_obj):
slp_path_list = []
for file_path in slp_path_obj.iterdir():
if file_path.is_file():
slp_path_list.append(file_path)
elif Path.is_file(slp_path_obj):
slp_path_list = [args.data_path]

for output_path in slp_path_list:
assert Path(output_path).exists()
slp_path_list = [file for file in slp_path_obj.iterdir() if file.is_file()]

# Run inference
sleap_track(args=args)
slp_path = Path(slp_path)

# Assert predictions file exists
if Path.is_dir(slp_path_obj):
new_slp_path_list = []
for file_path in slp_path_obj.iterdir():
if file_path.is_file():
new_slp_path_list.append(file_path)
elif Path.is_file(slp_path):
new_slp_path_list = [args.data_path]

for output_path in slp_path_list:
assert Path(output_path).exists()
expected_extensions = {'.slp', '.mp4', '.avi'} # Add other video formats if necessary

for file_path in slp_path_list:
if file_path.suffix in expected_extensions:
expected_output_file = file_path.parent / (file_path.stem + ".predictions.slp")
print(f"PATH: {expected_output_file}")
assert Path(expected_output_file).exists()

new_slp_path_list = [file for file in slp_path_obj.iterdir() if file.is_file()]

files_to_remove = set(new_slp_path_list) - set(slp_path_list)
for file in files_to_remove:
file.unlink()

#@pytest.mark.parametrize("tracking", ["simple", "flow", "None"])
@pytest.mark.parametrize("tracking", ["simple", "flow", "None"])
def test_sleap_track_mult_input_mp4(
centered_pair_predictions: Labels,
min_centroid_model_path: str,
min_centered_instance_model_path: str,
test_sleap_track_mult_inputs_folder_mp4: str,
#tracking
tracking
):
slp_path = test_sleap_track_mult_inputs_folder_mp4
slp_path_obj = Path(slp_path)

# Create sleap-track command
args = (
f"{slp_path} --model {min_centroid_model_path} "
#f"--tracking.tracker {tracking} "
f"--tracking.tracker {tracking} "
f"--model {min_centered_instance_model_path} --video.index 0 --frames 1-3 --cpu"
).split()

slp_path_list = [file for file in slp_path_obj.iterdir() if file.is_file()]

# Run inference
sleap_track(args=args)
slp_path = Path(slp_path)

# Assert predictions file exists
expected_extensions = {'.slp', '.mp4', '.avi'} # Add other video formats if necessary

for file_path in slp_path_list:
if file_path.suffix in expected_extensions:
expected_output_file = file_path.parent / (file_path.stem + ".predictions.slp")
print(f"PATH: {expected_output_file}")
assert Path(expected_output_file).exists()

new_slp_path_list = [file for file in slp_path_obj.iterdir() if file.is_file()]

files_to_remove = set(new_slp_path_list) - set(slp_path_list)
for file in files_to_remove:
file.unlink()

def test_sleap_track_output_mult(
test_sleap_track_output_folder: str,
min_centroid_model_path: str,
min_centered_instance_model_path: str,
test_sleap_track_mult_inputs_folder_mp4: str,
):
slp_path = test_sleap_track_mult_inputs_folder_mp4
slp_path_obj = Path(slp_path)
output_path = test_sleap_track_output_folder
output_path_obj = Path(output_path)

# Create sleap-track command
args = (
f"{slp_path} --model {min_centroid_model_path} "
f"--tracking.tracker simple "
f"-o {output_path} "
f"--model {min_centered_instance_model_path} --video.index 0 --frames 1-3 --cpu"
).split()

if Path.is_dir(slp_path_obj):
slp_path_list = []
for file_path in slp_path_obj.iterdir():
if file_path.is_file():
slp_path_list.append(file_path)
elif Path.is_file(slp_path_obj):
slp_path_list = [args.data_path]
slp_path_list = [file for file in slp_path_obj.iterdir() if file.is_file()]

for output_path in slp_path_list:
assert Path(output_path).exists()
output_path_list = [file for file in output_path_obj.iterdir() if file.is_file()]

# Run inference
sleap_track(args=args)
slp_path = Path(slp_path)

# Assert predictions file exists
if Path.is_dir(slp_path_obj):
new_slp_path_list = []
for file_path in slp_path_obj.iterdir():
if file_path.is_file():
new_slp_path_list.append(file_path)
elif Path.is_file(slp_path):
new_slp_path_list = [args.data_path]
new_output_path_list = [file for file in output_path_obj.iterdir() if file.is_file()]

for output_path in slp_path_list:
assert Path(output_path).exists()
# Check if there are any files in the directory
expected_extensions = {'.slp', '.mp4', '.avi'} # Add other video formats if necessary

for file_path in slp_path_list:
if file_path.suffix in expected_extensions:
expected_output_file = Path(output_path) / (file_path.stem + ".predictions.slp")
print(f"PATH: {expected_output_file}")
assert Path(expected_output_file).exists()


#files_to_remove = set(new_slp_path_list) - set(slp_path_list)
#for file in files_to_remove:
# file.unlink()
files_to_remove = set(new_output_path_list) - set(output_path_list)
for file in files_to_remove:
file.unlink()


def test_flow_tracker(centered_pair_predictions: Labels, tmpdir):
Expand Down

0 comments on commit 35db452

Please sign in to comment.