Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allowing inference on multiple videos via sleap-track #1784

Merged
Merged
Changes from 1 commit
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
132a1ea
implementing proposed code changes from issue #1777
emdavis02 May 24, 2024
e867ec0
comments
emdavis02 May 24, 2024
babaa77
configuring output_path to support multiple video inputs
emdavis02 May 24, 2024
83f444a
fixing errors from preexisting test cases
emdavis02 May 24, 2024
b0ac880
Test case / code fixes
emdavis02 Jun 12, 2024
dcc7a63
extending test cases for mp4 folders
emdavis02 Jun 13, 2024
35db452
test case for output directory
emdavis02 Jun 18, 2024
6f0c929
black and code rabbit fixes
emdavis02 Jun 24, 2024
bd2b016
code rabbit fixes
emdavis02 Jun 24, 2024
ec4c26d
as_posix errors resolved
emdavis02 Jun 27, 2024
abdc57c
syntax error
emdavis02 Jul 8, 2024
5ffdc96
adding test data
emdavis02 Jul 8, 2024
f179f5e
black
emdavis02 Jul 8, 2024
af565cb
output error resolved
emdavis02 Jul 8, 2024
8568cc3
edited for push to dev branch
emdavis02 Jul 8, 2024
ead7af8
black
emdavis02 Jul 8, 2024
8f0df1c
errors fixed, test cases implemented
emdavis02 Jul 8, 2024
760059f
invalid output test and invalid input test
emdavis02 Jul 9, 2024
ff706d8
deleting debugging statements
emdavis02 Jul 9, 2024
beb5e1e
deleting print statements
emdavis02 Jul 9, 2024
55bfe4b
black
emdavis02 Jul 10, 2024
3b9cd45
deleting unnecessary test case
emdavis02 Jul 10, 2024
be02a7d
implemented tmpdir
emdavis02 Jul 10, 2024
6a481c3
deleting extraneous file
emdavis02 Jul 10, 2024
488edde
fixing broken test case
emdavis02 Jul 12, 2024
4443686
fixing test_sleap_track_invalid_output
emdavis02 Jul 12, 2024
d86123d
removing support for multiple slp files
emdavis02 Jul 15, 2024
ae11b8d
implementing talmo's comments
emdavis02 Jul 15, 2024
fb587e5
adding comments
emdavis02 Jul 17, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
219 changes: 134 additions & 85 deletions sleap/nn/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -5288,46 +5288,56 @@ def _make_provider_from_cli(args: argparse.Namespace) -> Tuple[Provider, str]:
A tuple of `(provider, data_path)` with the data `Provider` and path to the data
that was specified in the args.
"""

# Figure out which input path to use.
emdavis02 marked this conversation as resolved.
Show resolved Hide resolved
labels_path = getattr(args, "labels", None)
if labels_path is not None:
data_path = labels_path
else:
data_path = args.data_path

if data_path is None or data_path == "":

if Path.is_dir(data_path):
data_path_list = []
for file_path in data_path.iterdir():
if file_path.is_file():
data_path_list.append(file_path)
elif Path.is_file(data_path):
data_path_list = [args.data_path]
emdavis02 marked this conversation as resolved.
Show resolved Hide resolved
else:
raise ValueError(
"You must specify a path to a video or a labels dataset. "
"Run 'sleap-track -h' to see full command documentation."
)

if data_path.endswith(".slp"):
labels = sleap.load_file(data_path)

if args.only_labeled_frames:
provider = LabelsReader.from_user_labeled_frames(labels)
elif args.only_suggested_frames:
provider = LabelsReader.from_unlabeled_suggestions(labels)
elif getattr(args, "video.index") != "":
provider = VideoReader(
video=labels.videos[int(getattr(args, "video.index"))],
example_indices=frame_list(args.frames),
)
else:
provider = LabelsReader(labels)
provider_list = []
for data_path_file in data_path_list:
if data_path_file.endswith(".slp"):
labels = sleap.load_file(data_path_file)

if args.only_labeled_frames:
provider_list.append(LabelsReader.from_user_labeled_frames(labels))
elif args.only_suggested_frames:
provider_list.append(LabelsReader.from_unlabeled_suggestions(labels))
elif getattr(args, "video.index") != "":
provider_list.append(VideoReader(
video=labels.videos[int(getattr(args, "video.index"))],
example_indices=frame_list(args.frames),
))
else:
provider_list.append(LabelsReader(labels))

else:
print(f"Video: {data_path}")
# TODO: Clean this up.
video_kwargs = dict(
dataset=vars(args).get("video.dataset"),
input_format=vars(args).get("video.input_format"),
)
provider = VideoReader.from_filepath(
filename=data_path, example_indices=frame_list(args.frames), **video_kwargs
)
else:
print(f"Video: {data_path_file}")
# TODO: Clean this up.
video_kwargs = dict(
dataset=vars(args).get("video.dataset"),
input_format=vars(args).get("video.input_format"),
)
provider_list.append(VideoReader.from_filepath(
filename=data_path_file, example_indices=frame_list(args.frames), **video_kwargs
))

return provider, data_path
return provider_list, data_path_list


def _make_predictor_from_cli(args: argparse.Namespace) -> Predictor:
Expand Down Expand Up @@ -5461,7 +5471,7 @@ def main(args: Optional[list] = None):
print()

# Setup data loader.
provider, data_path = _make_provider_from_cli(args)
provider_list, data_path_list = _make_provider_from_cli(args)
emdavis02 marked this conversation as resolved.
Show resolved Hide resolved

# Setup tracker.
tracker = _make_tracker_from_cli(args)
Expand All @@ -5471,33 +5481,103 @@ def main(args: Optional[list] = None):

# Either run inference (and tracking) or just run tracking
if args.models is not None:
# Setup models.
predictor = _make_predictor_from_cli(args)
predictor.tracker = tracker

# Run inference!
labels_pr = predictor.predict(provider)

if output_path is None:
output_path = data_path + ".predictions.slp"

labels_pr.provenance["model_paths"] = predictor.model_paths
labels_pr.provenance["predictor"] = type(predictor).__name__
for data_path, provider in zip(data_path_list, provider_list):
# Setup models.
predictor = _make_predictor_from_cli(args)
predictor.tracker = tracker

# Run inference!
labels_pr = predictor.predict(provider)

if output_path is None:
output_path = data_path + ".predictions.slp"

labels_pr.provenance["model_paths"] = predictor.model_paths
labels_pr.provenance["predictor"] = type(predictor).__name__

if args.no_empty_frames:
# Clear empty frames if specified.
labels_pr.remove_empty_frames()

finish_timestamp = str(datetime.now())
total_elapsed = time() - t0
print("Finished inference at:", finish_timestamp)
print(f"Total runtime: {total_elapsed} secs")
print(f"Predicted frames: {len(labels_pr)}/{len(provider)}")

# Add provenance metadata to predictions.
labels_pr.provenance["sleap_version"] = sleap.__version__
labels_pr.provenance["platform"] = platform.platform()
labels_pr.provenance["command"] = " ".join(sys.argv)
labels_pr.provenance["data_path"] = data_path
labels_pr.provenance["output_path"] = output_path
labels_pr.provenance["total_elapsed"] = total_elapsed
labels_pr.provenance["start_timestamp"] = start_timestamp
labels_pr.provenance["finish_timestamp"] = finish_timestamp

print("Provenance:")
print(labels_pr.provenance)
emdavis02 marked this conversation as resolved.
Show resolved Hide resolved
print()

labels_pr.provenance["args"] = vars(args)

# Save results.
labels_pr.save(output_path)
print("Saved output:", output_path)

if args.open_in_gui:
subprocess.call(["sleap-label", output_path])

output_path = None

elif getattr(args, "tracking.tracker") is not None:
# Load predictions
print("Loading predictions...")
labels_pr = sleap.load_file(args.data_path)
frames = sorted(labels_pr.labeled_frames, key=lambda lf: lf.frame_idx)

print("Starting tracker...")
frames = run_tracker(frames=frames, tracker=tracker)
tracker.final_pass(frames)

labels_pr = Labels(labeled_frames=frames)

if output_path is None:
output_path = f"{data_path}.{tracker.get_name()}.slp"
for data_path in data_path_list:
# Load predictions
print("Loading predictions...")
labels_pr = sleap.load_file(args.data_path)
frames = sorted(labels_pr.labeled_frames, key=lambda lf: lf.frame_idx)

print("Starting tracker...")
frames = run_tracker(frames=frames, tracker=tracker)
tracker.final_pass(frames)

labels_pr = Labels(labeled_frames=frames)

if output_path is None:
output_path = f"{data_path}.{tracker.get_name()}.slp"

if args.no_empty_frames:
# Clear empty frames if specified.
labels_pr.remove_empty_frames()

finish_timestamp = str(datetime.now())
total_elapsed = time() - t0
print("Finished inference at:", finish_timestamp)
print(f"Total runtime: {total_elapsed} secs")
print(f"Predicted frames: {len(labels_pr)}/{len(provider)}")

# Add provenance metadata to predictions.
labels_pr.provenance["sleap_version"] = sleap.__version__
labels_pr.provenance["platform"] = platform.platform()
labels_pr.provenance["command"] = " ".join(sys.argv)
labels_pr.provenance["data_path"] = data_path
labels_pr.provenance["output_path"] = output_path
labels_pr.provenance["total_elapsed"] = total_elapsed
labels_pr.provenance["start_timestamp"] = start_timestamp
labels_pr.provenance["finish_timestamp"] = finish_timestamp

print("Provenance:")
pprint(labels_pr.provenance)
print()

labels_pr.provenance["args"] = vars(args)

# Save results.
labels_pr.save(output_path)
print("Saved output:", output_path)

if args.open_in_gui:
subprocess.call(["sleap-label", output_path])
emdavis02 marked this conversation as resolved.
Show resolved Hide resolved

else:
raise ValueError(
Expand All @@ -5507,35 +5587,4 @@ def main(args: Optional[list] = None):
"Use \"sleap-track --tracking.tracker ...' to specify tracker to use."
)

if args.no_empty_frames:
# Clear empty frames if specified.
labels_pr.remove_empty_frames()

finish_timestamp = str(datetime.now())
total_elapsed = time() - t0
print("Finished inference at:", finish_timestamp)
print(f"Total runtime: {total_elapsed} secs")
print(f"Predicted frames: {len(labels_pr)}/{len(provider)}")

# Add provenance metadata to predictions.
labels_pr.provenance["sleap_version"] = sleap.__version__
labels_pr.provenance["platform"] = platform.platform()
labels_pr.provenance["command"] = " ".join(sys.argv)
labels_pr.provenance["data_path"] = data_path
labels_pr.provenance["output_path"] = output_path
labels_pr.provenance["total_elapsed"] = total_elapsed
labels_pr.provenance["start_timestamp"] = start_timestamp
labels_pr.provenance["finish_timestamp"] = finish_timestamp

print("Provenance:")
pprint(labels_pr.provenance)
print()

labels_pr.provenance["args"] = vars(args)

# Save results.
labels_pr.save(output_path)
print("Saved output:", output_path)

if args.open_in_gui:
subprocess.call(["sleap-label", output_path])

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Refine the error message to provide more context and guidance.

- "Use \"sleap-track --tracking.tracker ...' to specify tracker to use."
+ "Specify the tracker using 'sleap-track --tracking.tracker'. Refer to the documentation for more details."

Committable suggestion was skipped due to low confidence.