Skip to content

Commit

Permalink
remove unnecessary for loop, adjust tests, change default threshold
Browse files Browse the repository at this point in the history
  • Loading branch information
grquach committed Jul 17, 2024
2 parents 8cc046c + a17e5c8 commit 3baf219
Show file tree
Hide file tree
Showing 31 changed files with 114 additions and 17 deletions.
6 changes: 4 additions & 2 deletions docs/guides/cli.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ optional arguments:

```none
usage: sleap-train [-h] [--video-paths VIDEO_PATHS] [--val_labels VAL_LABELS]
[--test_labels TEST_LABELS] [--tensorboard] [--save_viz]
[--zmq] [--run_name RUN_NAME] [--prefix PREFIX]
[--test_labels TEST_LABELS] [--tensorboard] [--save_viz]
[--keep_viz] [--zmq] [--run_name RUN_NAME] [--prefix PREFIX]
[--suffix SUFFIX]
training_job_path [labels_path]
Expand Down Expand Up @@ -68,6 +68,8 @@ optional arguments:
--save_viz Enable saving of prediction visualizations to the run
folder if not already specified in the training job
config.
--keep_viz Keep prediction visualization images in the run
folder after training if --save_viz is enabled.
--zmq Enable ZMQ logging (for GUI) if not already specified
in the training job config.
--run_name RUN_NAME Run name to use when saving file, overrides other run
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@
" \"runs_folder\": \"models\",\n",
" \"tags\": [],\n",
" \"save_visualizations\": true,\n",
" \"delete_viz_images\": true,\n",
" \"keep_viz_images\": true,\n",
" \"zip_outputs\": false,\n",
" \"log_to_csv\": true,\n",
" \"checkpointing\": {\n",
Expand Down Expand Up @@ -727,7 +727,7 @@
" \"runs_folder\": \"models\",\n",
" \"tags\": [],\n",
" \"save_visualizations\": true,\n",
" \"delete_viz_images\": true,\n",
" \"keep_viz_images\": true,\n",
" \"zip_outputs\": false,\n",
" \"log_to_csv\": true,\n",
" \"checkpointing\": {\n",
Expand Down
5 changes: 5 additions & 0 deletions sleap/config/pipeline_form.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,11 @@ training:
type: bool
default: true

- name: _keep_viz
label: Keep Prediction Visualization Images After Training
type: bool
default: false

- name: _predict_frames
label: Predict On
type: list
Expand Down
17 changes: 14 additions & 3 deletions sleap/gui/learning/runners.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Run training/inference in background process via CLI."""

import abc
import attr
import os
Expand Down Expand Up @@ -500,9 +501,11 @@ def write_pipeline_files(
"data_path": os.path.basename(data_path),
"models": [Path(p).as_posix() for p in new_cfg_filenames],
"output_path": prediction_output_path,
"type": "labels"
if type(item_for_inference) == DatasetItemForInference
else "video",
"type": (
"labels"
if type(item_for_inference) == DatasetItemForInference
else "video"
),
"only_suggested_frames": only_suggested_frames,
"tracking": tracking_args,
}
Expand Down Expand Up @@ -544,6 +547,7 @@ def run_learning_pipeline(
"""

save_viz = inference_params.get("_save_viz", False)
keep_viz = inference_params.get("_keep_viz", False)

if "movenet" in inference_params["_pipeline"]:
trained_job_paths = [inference_params["_pipeline"]]
Expand All @@ -557,6 +561,7 @@ def run_learning_pipeline(
inference_params=inference_params,
gui=True,
save_viz=save_viz,
keep_viz=keep_viz,
)

# Check that all the models were trained
Expand Down Expand Up @@ -585,6 +590,7 @@ def run_gui_training(
inference_params: Dict[str, Any],
gui: bool = True,
save_viz: bool = False,
keep_viz: bool = False,
) -> Dict[Text, Text]:
"""
Runs training for each training job.
Expand All @@ -594,6 +600,7 @@ def run_gui_training(
config_info_list: List of ConfigFileInfo with configs for training.
gui: Whether to show gui windows and process gui events.
save_viz: Whether to save visualizations from training.
keep_viz: Whether to keep prediction visualization images after training.
Returns:
Dictionary, keys are head name, values are path to trained config.
Expand Down Expand Up @@ -683,6 +690,7 @@ def waiting():
video_paths=video_path_list,
waiting_callback=waiting,
save_viz=save_viz,
keep_viz=keep_viz,
)

if ret == "success":
Expand Down Expand Up @@ -825,6 +833,7 @@ def train_subprocess(
video_paths: Optional[List[Text]] = None,
waiting_callback: Optional[Callable] = None,
save_viz: bool = False,
keep_viz: bool = False,
):
"""Runs training inside subprocess."""
run_path = job_config.outputs.run_path
Expand Down Expand Up @@ -853,6 +862,8 @@ def train_subprocess(

if save_viz:
cli_args.append("--save_viz")
if keep_viz:
cli_args.append("--keep_viz")

# Use cli arg since cli ignores setting in config
if job_config.outputs.tensorboard.write_logs:
Expand Down
6 changes: 3 additions & 3 deletions sleap/nn/config/outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,8 +151,8 @@ class OutputsConfig:
save_visualizations: If True, will render and save visualizations of the model
predictions as PNGs to "{run_folder}/viz/{split}.{epoch:04d}.png", where the
split is one of "train", "validation", "test".
delete_viz_images: If True, delete the saved visualizations after training
completes. This is useful to reduce the model folder size if you do not need
keep_viz_images: If True, keep the saved visualization images after training
completes. This is useful unchecked to reduce the model folder size if you do not need
to keep the visualization images.
zip_outputs: If True, compress the run folder to a zip file. This will be named
"{run_folder}.zip".
Expand All @@ -170,7 +170,7 @@ class OutputsConfig:
runs_folder: Text = "models"
tags: List[Text] = attr.ib(factory=list)
save_visualizations: bool = True
delete_viz_images: bool = True
keep_viz_images: bool = False
zip_outputs: bool = False
log_to_csv: bool = True
checkpointing: CheckpointingConfig = attr.ib(factory=CheckpointingConfig)
Expand Down
13 changes: 11 additions & 2 deletions sleap/nn/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -946,7 +946,7 @@ def train(self):
if self.config.outputs.save_outputs:
if (
self.config.outputs.save_visualizations
and self.config.outputs.delete_viz_images
and not self.config.outputs.keep_viz_images
):
self.cleanup()

Expand Down Expand Up @@ -997,7 +997,7 @@ def cleanup(self):

def package(self):
"""Package model folder into a zip file for portability."""
if self.config.outputs.delete_viz_images:
if not self.config.outputs.keep_viz_images:
self.cleanup()
logger.info(f"Packaging results to: {self.run_path}.zip")
shutil.make_archive(
Expand Down Expand Up @@ -1864,6 +1864,14 @@ def create_trainer_using_cli(args: Optional[List] = None):
"already specified in the training job config."
),
)
parser.add_argument(
"--keep_viz",
action="store_true",
help=(
"Keep prediction visualization images in the run folder after training when "
"--save_viz is enabled."
),
)
parser.add_argument(
"--zmq",
action="store_true",
Expand Down Expand Up @@ -1949,6 +1957,7 @@ def create_trainer_using_cli(args: Optional[List] = None):
if args.suffix != "":
job_config.outputs.run_name_suffix = args.suffix
job_config.outputs.save_visualizations |= args.save_viz
job_config.outputs.keep_viz_images = args.keep_viz
if args.labels_path == "":
args.labels_path = None
args.video_paths = args.video_paths.split(",")
Expand Down
1 change: 1 addition & 0 deletions sleap/training_profiles/baseline.centroid.json
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@
"runs_folder": "models",
"tags": [],
"save_visualizations": true,
"keep_viz_images": false,
"log_to_csv": true,
"checkpointing": {
"initial_model": false,
Expand Down
1 change: 1 addition & 0 deletions sleap/training_profiles/baseline_large_rf.bottomup.json
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@
"runs_folder": "models",
"tags": [],
"save_visualizations": true,
"keep_viz_images": false,
"log_to_csv": true,
"checkpointing": {
"initial_model": false,
Expand Down
1 change: 1 addition & 0 deletions sleap/training_profiles/baseline_large_rf.single.json
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@
"runs_folder": "models",
"tags": [],
"save_visualizations": true,
"keep_viz_images": false,
"log_to_csv": true,
"checkpointing": {
"initial_model": false,
Expand Down
1 change: 1 addition & 0 deletions sleap/training_profiles/baseline_large_rf.topdown.json
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@
"runs_folder": "models",
"tags": [],
"save_visualizations": true,
"keep_viz_images": false,
"log_to_csv": true,
"checkpointing": {
"initial_model": false,
Expand Down
1 change: 1 addition & 0 deletions sleap/training_profiles/baseline_medium_rf.bottomup.json
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@
"runs_folder": "models",
"tags": [],
"save_visualizations": true,
"keep_viz_images": false,
"log_to_csv": true,
"checkpointing": {
"initial_model": false,
Expand Down
1 change: 1 addition & 0 deletions sleap/training_profiles/baseline_medium_rf.single.json
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@
"runs_folder": "models",
"tags": [],
"save_visualizations": true,
"keep_viz_images": false,
"log_to_csv": true,
"checkpointing": {
"initial_model": false,
Expand Down
1 change: 1 addition & 0 deletions sleap/training_profiles/baseline_medium_rf.topdown.json
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@
"runs_folder": "models",
"tags": [],
"save_visualizations": true,
"keep_viz_images": false,
"log_to_csv": true,
"checkpointing": {
"initial_model": false,
Expand Down
1 change: 1 addition & 0 deletions sleap/training_profiles/pretrained.bottomup.json
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@
"runs_folder": "models",
"tags": [],
"save_visualizations": true,
"keep_viz_images": false,
"log_to_csv": true,
"checkpointing": {
"initial_model": false,
Expand Down
1 change: 1 addition & 0 deletions sleap/training_profiles/pretrained.centroid.json
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@
"runs_folder": "models",
"tags": [],
"save_visualizations": true,
"keep_viz_images": false,
"log_to_csv": true,
"checkpointing": {
"initial_model": false,
Expand Down
1 change: 1 addition & 0 deletions sleap/training_profiles/pretrained.single.json
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@
"runs_folder": "models",
"tags": [],
"save_visualizations": true,
"keep_viz_images": false,
"log_to_csv": true,
"checkpointing": {
"initial_model": false,
Expand Down
1 change: 1 addition & 0 deletions sleap/training_profiles/pretrained.topdown.json
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@
"runs_folder": "models",
"tags": [],
"save_visualizations": true,
"keep_viz_images": false,
"log_to_csv": true,
"checkpointing": {
"initial_model": false,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@
"runs_folder": "models",
"tags": [],
"save_visualizations": false,
"keep_viz_images": false,
"log_to_csv": true,
"checkpointing": {
"initial_model": false,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,7 @@
"runs_folder": "models",
"tags": [],
"save_visualizations": false,
"keep_viz_images": false,
"log_to_csv": true,
"checkpointing": {
"initial_model": false,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@
"runs_folder": "models",
"tags": [],
"save_visualizations": false,
"delete_viz_images": true,
"keep_viz_images": false,
"zip_outputs": false,
"log_to_csv": true,
"checkpointing": {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@
"runs_folder": "models",
"tags": [],
"save_visualizations": false,
"delete_viz_images": true,
"keep_viz_images": false,
"zip_outputs": false,
"log_to_csv": true,
"checkpointing": {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@
"runs_folder": "models",
"tags": [],
"save_visualizations": false,
"keep_viz_images": false,
"log_to_csv": true,
"checkpointing": {
"initial_model": false,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@
"runs_folder": "models",
"tags": [],
"save_visualizations": false,
"keep_viz_images": false,
"log_to_csv": true,
"checkpointing": {
"initial_model": false,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@
"runs_folder": "models",
"tags": [],
"save_visualizations": false,
"keep_viz_images": false,
"log_to_csv": true,
"checkpointing": {
"initial_model": false,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@
"runs_folder": "models",
"tags": [],
"save_visualizations": false,
"keep_viz_images": false,
"log_to_csv": true,
"checkpointing": {
"initial_model": false,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@
"runs_folder": "models",
"tags": [],
"save_visualizations": false,
"keep_viz_images": false,
"log_to_csv": true,
"checkpointing": {
"initial_model": false,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@
"runs_folder": "models",
"tags": [],
"save_visualizations": false,
"keep_viz_images": false,
"log_to_csv": true,
"checkpointing": {
"initial_model": false,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@
""
],
"save_visualizations": false,
"keep_viz_images": true,
"log_to_csv": true,
"checkpointing": {
"initial_model": false,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@
""
],
"save_visualizations": false,
"keep_viz_images": true,
"log_to_csv": true,
"checkpointing": {
"initial_model": false,
Expand Down
1 change: 0 additions & 1 deletion tests/gui/test_dialogs.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""Module to test the dialogs of the GUI (contained in sleap/gui/dialogs)."""


import os
from pathlib import Path

Expand Down
Loading

0 comments on commit 3baf219

Please sign in to comment.