Skip to content

Commit

Permalink
removing support for multiple slp files
Browse files Browse the repository at this point in the history
  • Loading branch information
emdavis02 committed Jul 15, 2024
1 parent 4443686 commit d86123d
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 51 deletions.
52 changes: 24 additions & 28 deletions sleap/nn/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -5315,7 +5315,10 @@ def _make_provider_from_cli(args: argparse.Namespace) -> Tuple[Provider, str]:
output_data_path_list = []
for file_path in data_path_list:
# Create a provider for each file
if file_path.as_posix().endswith(".slp"):
if file_path.as_posix().endswith(".slp") and len(data_path_list) > 1:
print(f"slp file skipped: {file_path.as_posix()}")

elif file_path.as_posix().endswith(".slp"):
labels = sleap.load_file(file_path.as_posix())

if args.only_labeled_frames:
Expand Down Expand Up @@ -5369,7 +5372,6 @@ def _make_predictor_from_cli(args: argparse.Namespace) -> Predictor:
Returns:
The `Predictor` created from loaded models.
"""
print(args)
peak_threshold = None
for deprecated_arg in [
"single.peak_threshold",
Expand Down Expand Up @@ -5450,7 +5452,7 @@ def main(args: Optional[list] = None):
# Parse inputs.
args, _ = parser.parse_known_args(args)
print("Args:")
print(vars(args))
pprint(vars(args))
print()

# Setup devices.
Expand Down Expand Up @@ -5514,29 +5516,24 @@ def main(args: Optional[list] = None):
labels_pr = predictor.predict(provider)

# if output path was not provided, create an output path
if output_path is not None:
output_path_obj = Path(output_path)
if (
output_path_obj.exists()
and output_path_obj.is_file()
and len(data_path_list) > 1
):
raise ValueError(
"output_path argument must be a directory if multiple video inputs are given"
)

if output_path is None:
output_path = f"{data_path.as_posix()}.predictions.slp"
output_path_obj = Path(output_path)

# if output_path was provided and multiple inputs were provided, create a directory to store outputs
elif len(data_path_list) > 1:
output_path = (
output_path_obj.as_posix()
+ "/"
+ (data_path_obj.stem + ".predictions.slp")
)
else:
output_path_obj = Path(output_path)
if output_path_obj.is_file() and len(data_path_list) > 1:
raise ValueError(
"output_path argument must be a directory if multiple video inputs are given"
)

# if output_path was provided and multiple inputs were provided, create a directory to store outputs
if len(data_path_list) > 1:
output_path = (
output_path_obj
/ data_path_obj.with_suffix(".predictions.slp").name
)
output_path_obj = Path(output_path)

labels_pr.provenance["model_paths"] = predictor.model_paths
labels_pr.provenance["predictor"] = type(predictor).__name__
Expand All @@ -5562,7 +5559,7 @@ def main(args: Optional[list] = None):
labels_pr.provenance["finish_timestamp"] = finish_timestamp

print("Provenance:")
print(labels_pr.provenance)
pprint(labels_pr.provenance)
print()

labels_pr.provenance["args"] = vars(args)
Expand Down Expand Up @@ -5596,7 +5593,7 @@ def main(args: Optional[list] = None):
output_path = f"{data_path}.{tracker.get_name()}.slp"
output_path_obj = Path(output_path)

if output_path is not None:
else:
output_path_obj = Path(output_path)
if (
output_path_obj.exists()
Expand All @@ -5608,12 +5605,11 @@ def main(args: Optional[list] = None):
)

elif not output_path_obj.exists() and len(data_path_list) > 1:
output_path = (
output_path_obj.as_posix()
+ "/"
+ (data_path_obj.stem + ".predictions.slp")
output_path = output_path_obj / data_path_obj.with_suffix(
".predictions.slp"
)
output_path_obj = Path(output_path)
output_path_obj.parent.mkdir(exist_ok=True, parents=True)

if args.no_empty_frames:
# Clear empty frames if specified.
Expand All @@ -5636,7 +5632,7 @@ def main(args: Optional[list] = None):
labels_pr.provenance["finish_timestamp"] = finish_timestamp

print("Provenance:")
print(labels_pr.provenance)
pprint(labels_pr.provenance)
print()

labels_pr.provenance["args"] = vars(args)
Expand Down
26 changes: 3 additions & 23 deletions tests/nn/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -1557,9 +1557,7 @@ def test_sleap_track_mult_input_slp(

# Assert predictions file exists
expected_extensions = {
".slp",
".mp4",
".avi",
} # Add other video formats if necessary

for file_path in slp_path_list:
Expand Down Expand Up @@ -1605,14 +1603,8 @@ def test_sleap_track_mult_input_slp_mp4(
sleap_track(args=args)

# Assert predictions file exists
expected_extensions = {
".slp",
".mp4",
".avi",
} # Add other video formats if necessary

for file_path in slp_path_list:
if file_path.suffix in expected_extensions:
if file_path.suffix == ".mp4":
expected_output_file = f"{file_path}.predictions.slp"
assert Path(expected_output_file).exists()

Expand Down Expand Up @@ -1651,14 +1643,8 @@ def test_sleap_track_mult_input_mp4(
sleap_track(args=args)

# Assert predictions file exists
expected_extensions = {
".slp",
".mp4",
".avi",
} # Add other video formats if necessary

for file_path in slp_path_list:
if file_path.suffix in expected_extensions:
if file_path.suffix == ".mp4":
expected_output_file = f"{file_path}.predictions.slp"
assert Path(expected_output_file).exists()

Expand Down Expand Up @@ -1700,14 +1686,8 @@ def test_sleap_track_output_mult(
slp_path = Path(slp_path)

# Check if there are any files in the directory
expected_extensions = {
".slp",
".mp4",
".avi",
} # Add other video formats if necessary

for file_path in slp_path_list:
if file_path.suffix in expected_extensions:
if file_path.suffix == ".mp4":
expected_output_file = output_path_obj / (
file_path.stem + ".predictions.slp"
)
Expand Down

0 comments on commit d86123d

Please sign in to comment.