Skip to content

Commit

Permalink
v1.3.0 updates (#72)
Browse files Browse the repository at this point in the history
  • Loading branch information
themattinthehatt committed Apr 19, 2024
1 parent d8b1069 commit 283f653
Show file tree
Hide file tree
Showing 27 changed files with 578 additions and 709 deletions.
14 changes: 9 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,16 @@
![GitHub](https://img.shields.io/github/license/Lightning-Universe/Pose-app)
[![Documentation Status](https://readthedocs.org/projects/pose-app/badge/?version=latest)](https://pose-app.readthedocs.io/en/latest/?badge=latest)

This repo contains browser-based GUIs that facilitate the deveopment of a pose estimation project.
This repo contains a browser-based GUI that facilitate the deveopment of a pose estimation project.

We provide three different apps:
* `demo_app.py`: using provided example data, train and evaluate pose estimation models
* `labeling_app.py`: stand-alone labeling app where you can upload videos, extract frames, and annotate keypoints on extracted frames using LabelStudio
* `app.py`: full app that includes labeling, training, and evaluation
The app allows you to

* label data (upload videos, extract frames, annotate keypoints)
* train and evaluate models
* run inference on new videos

Additionally, the app comes with an example dataset if you want to explore, without needing to
label your own data.

Preprint: [Lightning Pose: improved animal pose estimation via semi-supervised learning, Bayesian ensembling, and cloud-native open-source tools](https://www.biorxiv.org/content/10.1101/2023.04.28.538703v1)

Expand Down
156 changes: 149 additions & 7 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,23 @@
"""

from lightning.app import CloudCompute, LightningApp, LightningFlow
from lightning.app.structures import Dict
import logging
import numpy as np
import os
import pandas as pd
import shutil
import sys
import time
import yaml

from lightning_pose_app import LABELSTUDIO_DB_DIR, LIGHTNING_POSE_DIR
from lightning_pose_app import (
COLLECTED_DATA_FILENAME,
LABELED_DATA_DIR,
LABELSTUDIO_DB_DIR,
LIGHTNING_POSE_DIR,
MODELS_DIR,
SELECTED_FRAMES_FILENAME,
)
from lightning_pose_app.bashwork import LitBashWork
from lightning_pose_app.label_studio.component import LitLabelStudio
from lightning_pose_app.ui.extract_frames import ExtractFramesUI
Expand Down Expand Up @@ -85,8 +94,142 @@ def __init__(self):
database_dir=os.path.join(self.data_dir, LABELSTUDIO_DB_DIR),
)

# works for inference
self.inference = Dict()
# start label studio
self.label_studio.run(action="start_label_studio")

# import mirror-mouse-example dataset
if not os.environ.get("TESTING_LAI"):
self.import_demo_dataset(
src_dir=os.path.join(LIGHTNING_POSE_DIR, "data", "mirror-mouse-example"),
dst_dir=os.path.join(self.data_dir[1:], "mirror-mouse-example")
)

def import_demo_dataset(self, src_dir, dst_dir):

src_dir_abs = os.path.join(os.path.dirname(__file__), src_dir)
proj_dir_abs = os.path.join(os.path.dirname(__file__), dst_dir)
if os.path.isdir(proj_dir_abs):
return

_logger.info("Importing demo dataset; this will only take a minute")

project_name = os.path.basename(dst_dir)

# -------------------------------
# copy data
# -------------------------------
# copy full example data directory over
shutil.copytree(src_dir_abs, proj_dir_abs)

# copy config file
config_file_dst = os.path.join(proj_dir_abs, f"model_config_{project_name}.yaml")
shutil.copyfile(
os.path.join(LIGHTNING_POSE_DIR, "scripts", "configs", f"config_{project_name}.yaml"),
config_file_dst,
)

# make csv file for label studio
n_frames = len(os.listdir(os.path.join(proj_dir_abs, LABELED_DATA_DIR)))
idxs_selected = np.arange(1, n_frames - 2) # we've stored mock context frames
n_digits = 2
extension = "png"
frames_to_label = np.sort(np.array(
["img%s.%s" % (str(idx).zfill(n_digits), extension) for idx in idxs_selected]
))
np.savetxt(
os.path.join(proj_dir_abs, LABELED_DATA_DIR, SELECTED_FRAMES_FILENAME),
frames_to_label,
delimiter=",",
fmt="%s",
)

# make models dir
os.makedirs(os.path.join(proj_dir_abs, MODELS_DIR), exist_ok=True)

# -------------------------------
# remove obstacle keypoints
# -------------------------------
config_dict = yaml.safe_load(open(config_file_dst))
config_dict["data"]["keypoint_names"] = [
"paw1LH_top",
"paw2LF_top",
"paw3RF_top",
"paw4RH_top",
"tailBase_top",
"tailMid_top",
"nose_top",
"paw1LH_bot",
"paw2LF_bot",
"paw3RF_bot",
"paw4RH_bot",
"tailBase_bot",
"tailMid_bot",
"nose_bot",
]
config_dict["data"]["columns_for_singleview_pca"] = [
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
]
config_dict["data"]["mirrored_column_matches"] = [
[0, 1, 2, 3, 4, 5, 6],
[7, 8, 9, 10, 11, 12, 13],
]
config_dict["losses"]["temporal"]["epsilon"] = 10
yaml.dump(config_dict, open(config_file_dst, "w"))

csv_file = os.path.join(proj_dir_abs, COLLECTED_DATA_FILENAME)
df = pd.read_csv(csv_file, index_col=0, header=[0, 1, 2])
df.drop('obs_top', axis=1, level=1, inplace=True)
df.drop('obsHigh_bot', axis=1, level=1, inplace=True)
df.drop('obsLow_bot', axis=1, level=1, inplace=True)
df.to_csv(csv_file)

# -------------------------------
# import project to labelstudio
# -------------------------------
# create project flow to help with upload
project_ui_demo = ProjectUI(
data_dir=self.data_dir,
default_config_dict=self.project_ui.default_config_dict,
debug=False, # if True, hard-code project details like n_views, keypoint_names, etc.
)
# update paths
project_ui_demo.run(action="update_paths", project_name=project_name)
# load project defaults
project_ui_demo.run(action="update_project_config")
# make keypoints field
keypoint_names = project_ui_demo.config_dict["data"]["keypoint_names"]
project_ui_demo.run(
action="update_project_config",
new_vals_dict={
"data": {"keypoints": keypoint_names, "num_keypoints": len(keypoint_names)}
},
)

# import to labelstudio
self.label_studio.run(
action="update_paths",
proj_dir=project_ui_demo.proj_dir,
proj_name=project_name,
)
self.label_studio.run(
action="create_labeling_config_xml",
keypoints=project_ui_demo.config_dict["data"]["keypoints"],
)
self.label_studio.run(action="create_new_project")
self.label_studio.run(action="import_existing_annotations")

# -------------------------------
# cleanup - reset labelstudio
# -------------------------------
self.label_studio.proj_dir = None
self.label_studio.proj_name = None
self.label_studio.keypoints = None
for key, val in self.label_studio.filenames.items():
self.label_studio.filenames[key] = ""
self.label_studio.counts["create_new_project"] = 0
self.label_studio.counts["import_existing_annotations"] = 0

del project_ui_demo

def start_tensorboard(self, logdir):
"""run tensorboard"""
Expand Down Expand Up @@ -129,7 +272,6 @@ def run(self):
# -------------------------------------------------------------
# start background services (run only once)
# -------------------------------------------------------------
self.label_studio.run(action="start_label_studio")
self.start_fiftyone()
if self.project_ui.model_dir is not None:
# find previously trained models for project, expose to training and diagnostics UIs
Expand All @@ -138,7 +280,7 @@ def run(self):
# only launch once we know which project we're working on
self.start_tensorboard(logdir=self.project_ui.model_dir[1:])
self.streamlit_frame.run(action="initialize")
self.streamlit_video.run(action="initialize")
self.streamlit_video.run(action="initialize")

# -------------------------------------------------------------
# update project data (user has clicked button in project UI)
Expand Down Expand Up @@ -312,9 +454,9 @@ def configure_layout(self):
train_tab,
train_status_tab,
st_frame_tab,
fo_tab,
st_video_tab,
st_video_player_tab,
fo_tab,
]
else:
return [
Expand Down
Loading

0 comments on commit 283f653

Please sign in to comment.