Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allowing inference on multiple videos via sleap-track #1784

Merged
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
132a1ea
implementing proposed code changes from issue #1777
emdavis02 May 24, 2024
e867ec0
comments
emdavis02 May 24, 2024
babaa77
configuring output_path to support multiple video inputs
emdavis02 May 24, 2024
83f444a
fixing errors from preexisting test cases
emdavis02 May 24, 2024
b0ac880
Test case / code fixes
emdavis02 Jun 12, 2024
dcc7a63
extending test cases for mp4 folders
emdavis02 Jun 13, 2024
35db452
test case for output directory
emdavis02 Jun 18, 2024
6f0c929
black and code rabbit fixes
emdavis02 Jun 24, 2024
bd2b016
code rabbit fixes
emdavis02 Jun 24, 2024
ec4c26d
as_posix errors resolved
emdavis02 Jun 27, 2024
abdc57c
syntax error
emdavis02 Jul 8, 2024
5ffdc96
adding test data
emdavis02 Jul 8, 2024
f179f5e
black
emdavis02 Jul 8, 2024
af565cb
output error resolved
emdavis02 Jul 8, 2024
8568cc3
edited for push to dev branch
emdavis02 Jul 8, 2024
ead7af8
black
emdavis02 Jul 8, 2024
8f0df1c
errors fixed, test cases implemented
emdavis02 Jul 8, 2024
760059f
invalid output test and invalid input test
emdavis02 Jul 9, 2024
ff706d8
deleting debugging statements
emdavis02 Jul 9, 2024
beb5e1e
deleting print statements
emdavis02 Jul 9, 2024
55bfe4b
black
emdavis02 Jul 10, 2024
3b9cd45
deleting unnecessary test case
emdavis02 Jul 10, 2024
be02a7d
implemented tmpdir
emdavis02 Jul 10, 2024
6a481c3
deleting extraneous file
emdavis02 Jul 10, 2024
488edde
fixing broken test case
emdavis02 Jul 12, 2024
4443686
fixing test_sleap_track_invalid_output
emdavis02 Jul 12, 2024
d86123d
removing support for multiple slp files
emdavis02 Jul 15, 2024
ae11b8d
implementing talmo's comments
emdavis02 Jul 15, 2024
fb587e5
adding comments
emdavis02 Jul 17, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
264 changes: 179 additions & 85 deletions sleap/nn/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -5288,46 +5288,76 @@ def _make_provider_from_cli(args: argparse.Namespace) -> Tuple[Provider, str]:
A tuple of `(provider, data_path)` with the data `Provider` and path to the data
that was specified in the args.
"""

# Figure out which input path to use.
emdavis02 marked this conversation as resolved.
Show resolved Hide resolved
labels_path = getattr(args, "labels", None)
if labels_path is not None:
data_path = labels_path
data_path = Path(labels_path)
else:
data_path = Path(args.data_path)

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Optimize conditional assignment using a ternary operator.

The current if-else block for setting data_path can be simplified using a ternary operator, which makes the code cleaner and more concise.

-    if labels_path is not None:
-        data_path = Path(labels_path)
-    else:
-        data_path = Path(args.data_path)
+    data_path = Path(labels_path) if labels_path is not None else Path(args.data_path)
Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
# Figure out which input path to use.
labels_path = getattr(args, "labels", None)
if labels_path is not None:
data_path = labels_path
data_path = Path(labels_path)
else:
data_path = Path(args.data_path)
# Figure out which input path to use.
labels_path = getattr(args, "labels", None)
data_path = Path(labels_path) if labels_path is not None else Path(args.data_path)
Tools
Ruff

5294-5297: Use ternary operator data_path = Path(labels_path) if labels_path is not None else Path(args.data_path) instead of if-else-block

Replace if-else-block with data_path = Path(labels_path) if labels_path is not None else Path(args.data_path)

(SIM108)

# Check for multiple video inputs
# Compile file(s) into a list for later itteration
if data_path.is_dir():
data_path_list = []
for file_path in data_path.iterdir():
if file_path.is_file():
data_path_list.append(Path(file_path))
elif data_path.is_file():
data_path_list = [data_path]
else:
data_path = args.data_path

if data_path is None or data_path == "":
raise ValueError(
"You must specify a path to a video or a labels dataset. "
"Run 'sleap-track -h' to see full command documentation."
)

if data_path.endswith(".slp"):
labels = sleap.load_file(data_path)

if args.only_labeled_frames:
provider = LabelsReader.from_user_labeled_frames(labels)
elif args.only_suggested_frames:
provider = LabelsReader.from_unlabeled_suggestions(labels)
elif getattr(args, "video.index") != "":
provider = VideoReader(
video=labels.videos[int(getattr(args, "video.index"))],
example_indices=frame_list(args.frames),
)
# Provider list to accomodate multiple video inputs
provider_list = []
tmp_data_path_list = []
for data_path_file in data_path_list:
# Create a provider for each file
if data_path_file.as_posix().endswith(".slp"):
print(f"Sleap file: {data_path_file}")
labels = sleap.load_file(data_path_file.as_posix())

if args.only_labeled_frames:
provider_list.append(LabelsReader.from_user_labeled_frames(labels))
elif args.only_suggested_frames:
provider_list.append(LabelsReader.from_unlabeled_suggestions(labels))
elif getattr(args, "video.index") != "":
provider_list.append(
VideoReader(
video=labels.videos[int(getattr(args, "video.index"))],
example_indices=frame_list(args.frames),
)
)
else:
provider_list.append(LabelsReader(labels))

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider refactoring provider creation for clarity and modularity.

The code for creating providers based on file types is spread across multiple conditional branches. Consider refactoring this into a separate function or method to improve clarity and maintainability.

def create_provider(data_path_file, args):
    if data_path_file.as_posix().endswith(".slp"):
        labels = sleap.load_file(data_path_file.as_posix())
        # Additional conditions based on args
    else:
        video_kwargs = dict(
            dataset=vars(args).get("video.dataset"),
            input_format=vars(args).get("video.input_format"),
        )
        return VideoReader.from_filepath(filename=data_path_file.as_posix(), example_indices=frame_list(args.frames), **video_kwargs)

    return None  # or appropriate return

Also applies to: 5347-5352

tmp_data_path_list.append(data_path_file)

else:
provider = LabelsReader(labels)
try:
video_kwargs = dict(
dataset=vars(args).get("video.dataset"),
input_format=vars(args).get("video.input_format"),
)
provider_list.append(
VideoReader.from_filepath(
filename=data_path_file.as_posix(),
example_indices=frame_list(args.frames),
**video_kwargs,
)
)
print(f"Video: {data_path_file}")
tmp_data_path_list.append(data_path_file)
# TODO: Clean this up.
except Exception:
print(f"Error reading file: {data_path_file}")
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Improve error handling in file reading.

Currently, the error handling in the file reading process just prints an error message. It might be beneficial to either raise an exception or handle the error in a way that doesn't stop the execution, depending on the intended behavior of the application.

-            except Exception:
-                print(f"Error reading file: {data_path_file}")
+            except Exception as e:
+                logging.error(f"Failed to read file {data_path_file}: {e}")
+                continue  # or other handling logic

Committable suggestion was skipped due to low confidence.


else:
print(f"Video: {data_path}")
# TODO: Clean this up.
video_kwargs = dict(
dataset=vars(args).get("video.dataset"),
input_format=vars(args).get("video.input_format"),
)
provider = VideoReader.from_filepath(
filename=data_path, example_indices=frame_list(args.frames), **video_kwargs
)
data_path_list = tmp_data_path_list
emdavis02 marked this conversation as resolved.
Show resolved Hide resolved

return provider, data_path
return provider_list, data_path_list
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Logic to handle multiple video inputs looks good.

The logic to compile files into a list for iteration and the creation of providers based on file types are well implemented. However, consider refactoring the provider creation into a separate function for better modularity and clarity.

def create_provider(data_path_file, args):
    if data_path_file.as_posix().endswith(".slp"):
        labels = sleap.load_file(data_path_file.as_posix())
        # Additional conditions based on args
    else:
        video_kwargs = dict(
            dataset=vars(args).get("video.dataset"),
            input_format=vars(args).get("video.input_format"),
        )
        return VideoReader.from_filepath(filename=data_path_file.as_posix(), example_indices=frame_list(args.frames), **video_kwargs)

    return None  # or appropriate return



def _make_predictor_from_cli(args: argparse.Namespace) -> Predictor:
Expand All @@ -5339,6 +5369,7 @@ def _make_predictor_from_cli(args: argparse.Namespace) -> Predictor:
Returns:
The `Predictor` created from loaded models.
"""
print(args)
talmo marked this conversation as resolved.
Show resolved Hide resolved
emdavis02 marked this conversation as resolved.
Show resolved Hide resolved
peak_threshold = None
for deprecated_arg in [
"single.peak_threshold",
Expand Down Expand Up @@ -5419,11 +5450,9 @@ def main(args: Optional[list] = None):
# Parse inputs.
args, _ = parser.parse_known_args(args)
print("Args:")
pprint(vars(args))
print(vars(args))
talmo marked this conversation as resolved.
Show resolved Hide resolved
emdavis02 marked this conversation as resolved.
Show resolved Hide resolved
print()

output_path = args.output

# Setup devices.
if args.cpu or not sleap.nn.system.is_gpu_system():
sleap.nn.system.use_cpu_only()
Expand Down Expand Up @@ -5461,43 +5490,141 @@ def main(args: Optional[list] = None):
print()

# Setup data loader.
provider, data_path = _make_provider_from_cli(args)
provider_list, data_path_list = _make_provider_from_cli(args)
emdavis02 marked this conversation as resolved.
Show resolved Hide resolved

# Setup tracker.
tracker = _make_tracker_from_cli(args)

output_path = args.output
if output_path is not None:
output_path_obj = Path(output_path)

# Output path given is a file, but multiple inputs were given
if output_path is not None and (
#TODO check if directory exists
#always specify output directory
Path.is_file(output_path_obj) and len(data_path_list) > 1
):
raise ValueError(
"output_path argument must be a directory if multiple video inputs are given"
)

if args.models is not None and "movenet" in args.models[0]:
args.models = args.models[0]

# Either run inference (and tracking) or just run tracking
# Either run inference (and tracking) or just run tracking (if using an existing prediction where inference has already been run)
if args.models is not None:
# Setup models.
predictor = _make_predictor_from_cli(args)
predictor.tracker = tracker

# Run inference!
labels_pr = predictor.predict(provider)
# Run inference on all files inputed
for data_path, provider in zip(data_path_list, provider_list):
# Setup models.
predictor = _make_predictor_from_cli(args)
print(f"predictor.tracker: {tracker}")
emdavis02 marked this conversation as resolved.
Show resolved Hide resolved
predictor.tracker = tracker

# Run inference!
labels_pr = predictor.predict(provider)

if output_path is None:
output_path = data_path + ".predictions.slp"
if output_path is None:

labels_pr.provenance["model_paths"] = predictor.model_paths
labels_pr.provenance["predictor"] = type(predictor).__name__
output_path = data_path.parent / (data_path.stem + ".predictions.slp")
output_path_obj = Path(output_path)

else:
output_path = output_path + "/" + (data_path.stem + ".predictions.slp")
output_path_obj = Path(output_path)
emdavis02 marked this conversation as resolved.
Show resolved Hide resolved

labels_pr.provenance["model_paths"] = predictor.model_paths
labels_pr.provenance["predictor"] = type(predictor).__name__

if args.no_empty_frames:
# Clear empty frames if specified.
labels_pr.remove_empty_frames()

finish_timestamp = str(datetime.now())
total_elapsed = time() - t0
print("Finished inference at:", finish_timestamp)
print(f"Total runtime: {total_elapsed} secs")
print(f"Predicted frames: {len(labels_pr)}/{len(provider)}")

# Add provenance metadata to predictions.
labels_pr.provenance["sleap_version"] = sleap.__version__
labels_pr.provenance["platform"] = platform.platform()
labels_pr.provenance["command"] = " ".join(sys.argv)
labels_pr.provenance["data_path"] = data_path.as_posix()
labels_pr.provenance["output_path"] = output_path_obj.as_posix()
labels_pr.provenance["total_elapsed"] = total_elapsed
labels_pr.provenance["start_timestamp"] = start_timestamp
labels_pr.provenance["finish_timestamp"] = finish_timestamp

print("Provenance:")
print(labels_pr.provenance)
emdavis02 marked this conversation as resolved.
Show resolved Hide resolved
print()

labels_pr.provenance["args"] = vars(args)

# Save results.
labels_pr.save(output_path)
print("Saved output:", output_path)

if args.open_in_gui:
subprocess.call(["sleap-label", output_path])

# Reset output_path for next iteration
output_path = args.output

# running tracking on existing prediction file
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Improve error handling for missing tracker specification.

The error message is clear, but consider providing more guidance or a direct link to documentation on how to specify a tracker.

raise ValueError(
    "To retrack on predictions, must specify tracker. "
    "Use 'sleap-track --tracking.tracker ...' to specify tracker to use. "
    "See [documentation link] for more details."
)

elif getattr(args, "tracking.tracker") is not None:
# Load predictions
print("Loading predictions...")
labels_pr = sleap.load_file(args.data_path)
frames = sorted(labels_pr.labeled_frames, key=lambda lf: lf.frame_idx)
for data_path, provider in zip(data_path_list, provider_list):
# Load predictions
print("Loading predictions...")
labels_pr = sleap.load_file(data_path.as_posix())
frames = sorted(labels_pr.labeled_frames, key=lambda lf: lf.frame_idx)

print("Starting tracker...")
frames = run_tracker(frames=frames, tracker=tracker)
tracker.final_pass(frames)

print("Starting tracker...")
frames = run_tracker(frames=frames, tracker=tracker)
tracker.final_pass(frames)
labels_pr = Labels(labeled_frames=frames)

labels_pr = Labels(labeled_frames=frames)
if output_path is None:
output_path = f"{data_path}.{tracker.get_name()}.slp"

if output_path is None:
output_path = f"{data_path}.{tracker.get_name()}.slp"
if args.no_empty_frames:
# Clear empty frames if specified.
labels_pr.remove_empty_frames()

finish_timestamp = str(datetime.now())
total_elapsed = time() - t0
print("Finished inference at:", finish_timestamp)
print(f"Total runtime: {total_elapsed} secs")
print(f"Predicted frames: {len(labels_pr)}/{len(provider)}")

# Add provenance metadata to predictions.
labels_pr.provenance["sleap_version"] = sleap.__version__
labels_pr.provenance["platform"] = platform.platform()
labels_pr.provenance["command"] = " ".join(sys.argv)
labels_pr.provenance["data_path"] = data_path
labels_pr.provenance["output_path"] = output_path
labels_pr.provenance["total_elapsed"] = total_elapsed
labels_pr.provenance["start_timestamp"] = start_timestamp
labels_pr.provenance["finish_timestamp"] = finish_timestamp

print("Provenance:")
pprint(labels_pr.provenance)
print()

labels_pr.provenance["args"] = vars(args)

# Save results.
labels_pr.save(output_path)
print("Saved output:", output_path)

if args.open_in_gui:
subprocess.call(["sleap-label", output_path])

# Reset output_path for next iteration
output_path = args.output
emdavis02 marked this conversation as resolved.
Show resolved Hide resolved
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Review output path handling logic.

The handling of the output path, especially when multiple inputs are given, needs careful attention. Consider refining the logic to ensure that the output path is correctly set in all scenarios.

if output_path is not None and Path.is_file(output_path_obj) and len(data_path_list) > 1:
    raise ValueError(
        "output_path argument must be a directory if multiple video inputs are given"
    )
# Ensure that the output path is reset correctly after each iteration
+ original_output_path = args.output
for data_path, provider in zip(data_path_list, provider_list):
    ...
-    output_path = args.output
+    output_path = original_output_path
Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
provider_list, data_path_list = _make_provider_from_cli(args)
# Setup tracker.
tracker = _make_tracker_from_cli(args)
output_path = args.output
if output_path is not None:
output_path_obj = Path(output_path)
# Output path given is a file, but multiple inputs were given
if output_path is not None and (
#TODO check if directory exists
#always specify output directory
Path.is_file(output_path_obj) and len(data_path_list) > 1
):
raise ValueError(
"output_path argument must be a directory if multiple video inputs are given"
)
if args.models is not None and "movenet" in args.models[0]:
args.models = args.models[0]
# Either run inference (and tracking) or just run tracking
# Either run inference (and tracking) or just run tracking (if using an existing prediction where inference has already been run)
if args.models is not None:
# Setup models.
predictor = _make_predictor_from_cli(args)
predictor.tracker = tracker
# Run inference!
labels_pr = predictor.predict(provider)
# Run inference on all files inputed
for data_path, provider in zip(data_path_list, provider_list):
# Setup models.
predictor = _make_predictor_from_cli(args)
print(f"predictor.tracker: {tracker}")
predictor.tracker = tracker
# Run inference!
labels_pr = predictor.predict(provider)
if output_path is None:
output_path = data_path + ".predictions.slp"
if output_path is None:
labels_pr.provenance["model_paths"] = predictor.model_paths
labels_pr.provenance["predictor"] = type(predictor).__name__
output_path = data_path.parent / (data_path.stem + ".predictions.slp")
output_path_obj = Path(output_path)
else:
output_path = output_path + "/" + (data_path.stem + ".predictions.slp")
output_path_obj = Path(output_path)
labels_pr.provenance["model_paths"] = predictor.model_paths
labels_pr.provenance["predictor"] = type(predictor).__name__
if args.no_empty_frames:
# Clear empty frames if specified.
labels_pr.remove_empty_frames()
finish_timestamp = str(datetime.now())
total_elapsed = time() - t0
print("Finished inference at:", finish_timestamp)
print(f"Total runtime: {total_elapsed} secs")
print(f"Predicted frames: {len(labels_pr)}/{len(provider)}")
# Add provenance metadata to predictions.
labels_pr.provenance["sleap_version"] = sleap.__version__
labels_pr.provenance["platform"] = platform.platform()
labels_pr.provenance["command"] = " ".join(sys.argv)
labels_pr.provenance["data_path"] = data_path.as_posix()
labels_pr.provenance["output_path"] = output_path_obj.as_posix()
labels_pr.provenance["total_elapsed"] = total_elapsed
labels_pr.provenance["start_timestamp"] = start_timestamp
labels_pr.provenance["finish_timestamp"] = finish_timestamp
print("Provenance:")
print(labels_pr.provenance)
print()
labels_pr.provenance["args"] = vars(args)
# Save results.
labels_pr.save(output_path)
print("Saved output:", output_path)
if args.open_in_gui:
subprocess.call(["sleap-label", output_path])
# Reset output_path for next iteration
output_path = args.output
# running tracking on existing prediction file
elif getattr(args, "tracking.tracker") is not None:
# Load predictions
print("Loading predictions...")
labels_pr = sleap.load_file(args.data_path)
frames = sorted(labels_pr.labeled_frames, key=lambda lf: lf.frame_idx)
for data_path, provider in zip(data_path_list, provider_list):
# Load predictions
print("Loading predictions...")
labels_pr = sleap.load_file(data_path.as_posix())
frames = sorted(labels_pr.labeled_frames, key=lambda lf: lf.frame_idx)
print("Starting tracker...")
frames = run_tracker(frames=frames, tracker=tracker)
tracker.final_pass(frames)
print("Starting tracker...")
frames = run_tracker(frames=frames, tracker=tracker)
tracker.final_pass(frames)
labels_pr = Labels(labeled_frames=frames)
labels_pr = Labels(labeled_frames=frames)
if output_path is None:
output_path = f"{data_path}.{tracker.get_name()}.slp"
if output_path is None:
output_path = f"{data_path}.{tracker.get_name()}.slp"
if args.no_empty_frames:
# Clear empty frames if specified.
labels_pr.remove_empty_frames()
finish_timestamp = str(datetime.now())
total_elapsed = time() - t0
print("Finished inference at:", finish_timestamp)
print(f"Total runtime: {total_elapsed} secs")
print(f"Predicted frames: {len(labels_pr)}/{len(provider)}")
# Add provenance metadata to predictions.
labels_pr.provenance["sleap_version"] = sleap.__version__
labels_pr.provenance["platform"] = platform.platform()
labels_pr.provenance["command"] = " ".join(sys.argv)
labels_pr.provenance["data_path"] = data_path
labels_pr.provenance["output_path"] = output_path
labels_pr.provenance["total_elapsed"] = total_elapsed
labels_pr.provenance["start_timestamp"] = start_timestamp
labels_pr.provenance["finish_timestamp"] = finish_timestamp
print("Provenance:")
pprint(labels_pr.provenance)
print()
labels_pr.provenance["args"] = vars(args)
# Save results.
labels_pr.save(output_path)
print("Saved output:", output_path)
if args.open_in_gui:
subprocess.call(["sleap-label", output_path])
# Reset output_path for next iteration
output_path = args.output
provider_list, data_path_list = _make_provider_from_cli(args)
# Setup tracker.
tracker = _make_tracker_from_cli(args)
output_path = args.output
if output_path is not None:
output_path_obj = Path(output_path)
# Output path given is a file, but multiple inputs were given
if output_path is not None and (
#TODO check if directory exists
#always specify output directory
Path.is_file(output_path_obj) and len(data_path_list) > 1
):
raise ValueError(
"output_path argument must be a directory if multiple video inputs are given"
)
if args.models is not None and "movenet" in args.models[0]:
args.models = args.models[0]
# Either run inference (and tracking) or just run tracking (if using an existing prediction where inference has already been run)
if args.models is not None:
# Run inference on all files inputed
original_output_path = args.output
for data_path, provider in zip(data_path_list, provider_list):
# Setup models.
predictor = _make_predictor_from_cli(args)
print(f"predictor.tracker: {tracker}")
predictor.tracker = tracker
# Run inference!
labels_pr = predictor.predict(provider)
if output_path is None:
output_path = data_path.parent / (data_path.stem + ".predictions.slp")
output_path_obj = Path(output_path)
else:
output_path = output_path + "/" + (data_path.stem + ".predictions.slp")
output_path_obj = Path(output_path)
labels_pr.provenance["model_paths"] = predictor.model_paths
labels_pr.provenance["predictor"] = type(predictor).__name__
if args.no_empty_frames:
# Clear empty frames if specified.
labels_pr.remove_empty_frames()
finish_timestamp = str(datetime.now())
total_elapsed = time() - t0
print("Finished inference at:", finish_timestamp)
print(f"Total runtime: {total_elapsed} secs")
print(f"Predicted frames: {len(labels_pr)}/{len(provider)}")
# Add provenance metadata to predictions.
labels_pr.provenance["sleap_version"] = sleap.__version__
labels_pr.provenance["platform"] = platform.platform()
labels_pr.provenance["command"] = " ".join(sys.argv)
labels_pr.provenance["data_path"] = data_path.as_posix()
labels_pr.provenance["output_path"] = output_path_obj.as_posix()
labels_pr.provenance["total_elapsed"] = total_elapsed
labels_pr.provenance["start_timestamp"] = start_timestamp
labels_pr.provenance["finish_timestamp"] = finish_timestamp
print("Provenance:")
print(labels_pr.provenance)
print()
labels_pr.provenance["args"] = vars(args)
# Save results.
labels_pr.save(output_path)
print("Saved output:", output_path)
if args.open_in_gui:
subprocess.call(["sleap-label", output_path])
# Reset output_path for next iteration
output_path = original_output_path
# running tracking on existing prediction file
elif getattr(args, "tracking.tracker") is not None:
for data_path, provider in zip(data_path_list, provider_list):
# Load predictions
print("Loading predictions...")
labels_pr = sleap.load_file(data_path.as_posix())
frames = sorted(labels_pr.labeled_frames, key=lambda lf: lf.frame_idx)
print("Starting tracker...")
frames = run_tracker(frames=frames, tracker=tracker)
tracker.final_pass(frames)
labels_pr = Labels(labeled_frames=frames)
if output_path is None:
output_path = f"{data_path}.{tracker.get_name()}.slp"
if args.no_empty_frames:
# Clear empty frames if specified.
labels_pr.remove_empty_frames()
finish_timestamp = str(datetime.now())
total_elapsed = time() - t0
print("Finished inference at:", finish_timestamp)
print(f"Total runtime: {total_elapsed} secs")
print(f"Predicted frames: {len(labels_pr)}/{len(provider)}")
# Add provenance metadata to predictions.
labels_pr.provenance["sleap_version"] = sleap.__version__
labels_pr.provenance["platform"] = platform.platform()
labels_pr.provenance["command"] = " ".join(sys.argv)
labels_pr.provenance["data_path"] = data_path
labels_pr.provenance["output_path"] = output_path
labels_pr.provenance["total_elapsed"] = total_elapsed
labels_pr.provenance["start_timestamp"] = start_timestamp
labels_pr.provenance["finish_timestamp"] = finish_timestamp
print("Provenance:")
pprint(labels_pr.provenance)
print()
labels_pr.provenance["args"] = vars(args)
# Save results.
labels_pr.save(output_path)
print("Saved output:", output_path)
if args.open_in_gui:
subprocess.call(["sleap-label", output_path])
# Reset output_path for next iteration
output_path = args.output

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add output folder as an option when the input is a folder.

This suggestion addresses the need to specify an output folder when processing multiple video files from a directory.

Do you want me to generate the code to implement this feature or open a GitHub issue to track this task?

Comment on lines +5659 to +5660
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Resetting output_path might lead to unintended behavior.

The reset of output_path at the end of each iteration could lead to unexpected behavior in subsequent iterations, especially if args.output is modified during the process.

+    original_output_path = args.output
     for data_path, provider in zip(data_path_list, provider_list):
         ...
-        output_path = args.output
+        output_path = original_output_path
Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
# Reset output_path for next iteration
output_path = args.output
original_output_path = args.output
for data_path, provider in zip(data_path_list, provider_list):
...
output_path = original_output_path


else:
raise ValueError(
Expand All @@ -5506,36 +5633,3 @@ def main(args: Optional[list] = None):
"To retrack on predictions, must specify tracker. "
"Use \"sleap-track --tracking.tracker ...' to specify tracker to use."
)

if args.no_empty_frames:
# Clear empty frames if specified.
labels_pr.remove_empty_frames()

finish_timestamp = str(datetime.now())
total_elapsed = time() - t0
print("Finished inference at:", finish_timestamp)
print(f"Total runtime: {total_elapsed} secs")
print(f"Predicted frames: {len(labels_pr)}/{len(provider)}")

# Add provenance metadata to predictions.
labels_pr.provenance["sleap_version"] = sleap.__version__
labels_pr.provenance["platform"] = platform.platform()
labels_pr.provenance["command"] = " ".join(sys.argv)
labels_pr.provenance["data_path"] = data_path
labels_pr.provenance["output_path"] = output_path
labels_pr.provenance["total_elapsed"] = total_elapsed
labels_pr.provenance["start_timestamp"] = start_timestamp
labels_pr.provenance["finish_timestamp"] = finish_timestamp

print("Provenance:")
pprint(labels_pr.provenance)
print()

labels_pr.provenance["args"] = vars(args)

# Save results.
labels_pr.save(output_path)
print("Saved output:", output_path)

if args.open_in_gui:
subprocess.call(["sleap-label", output_path])
Loading
Loading