Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

uploading projects bugs fixing #65

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion lightning-pose
144 changes: 85 additions & 59 deletions lightning_pose_app/ui/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import shutil
import time
import zipfile
import re

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -281,7 +282,7 @@ def _upload_existing_project(self, **kwargs):
return
with zipfile.ZipFile(self.st_upload_existing_project_zippath) as z:
unzipped_dir = self.st_upload_existing_project_zippath.replace(".zip", "")
z.extractall(path=os.path.dirname(self.st_upload_existing_project_zippath))
z.extractall(path=unzipped_dir)

def contains_videos(file_or_dir):
if os.path.isfile(file_or_dir):
Expand All @@ -293,49 +294,70 @@ def contains_videos(file_or_dir):
else:
return False

# copy files over; not great that this is in a Flow, might take time
if self.st_existing_project_format == "Lightning Pose":
files_and_dirs = os.listdir(unzipped_dir)
for file_or_dir in files_and_dirs:
src = os.path.join(unzipped_dir, file_or_dir)
if file_or_dir.endswith(".csv"):
# copy labels csv file
dst = os.path.join(self.proj_dir_abs, COLLECTED_DATA_FILENAME)
shutil.copyfile(src, dst)
elif contains_videos(src):
# copy videos over, make sure they are in proper format
dst_dir = os.path.join(self.proj_dir_abs, file_or_dir)
print(src)
print(dst_dir)
copy_and_reformat_video_directory(src_dir=src, dst_dir=dst_dir)
else:
# copy other files
dst = os.path.join(self.proj_dir_abs, file_or_dir)
if os.path.isdir(src):
def find_top_level_dir(initial_path, target_file_name):
for root, dirs, files in os.walk(initial_path, topdown=True):
if target_file_name in files:
return root

finished_copy_files = False
try:
if self.st_existing_project_format == "Lightning Pose":
top_level_dir = find_top_level_dir(unzipped_dir, COLLECTED_DATA_FILENAME)
files_and_dirs = os.listdir(top_level_dir)
for file_or_dir in files_and_dirs:
src = os.path.join(top_level_dir, file_or_dir)
if file_or_dir.endswith(".csv"):
# copy labels csv file
dst = os.path.join(self.proj_dir_abs, COLLECTED_DATA_FILENAME)
shutil.copyfile(src, dst)
elif contains_videos(src):
# copy videos over, make sure they are in proper format
dst_dir = os.path.join(self.proj_dir_abs, file_or_dir)
copy_and_reformat_video_directory(src_dir=src, dst_dir=dst_dir)
else:
# copy other files
dst = os.path.join(self.proj_dir_abs, file_or_dir)
if os.path.isdir(src):
shutil.copytree(src, dst)
else:
shutil.copyfile(src, dst)

# flag finish coping all files
finished_copy_files = True

elif self.st_existing_project_format == "DLC":

# copy files
files_and_dirs = os.listdir(unzipped_dir)
req_dlc_dirs = ["labeled-data", "videos"]
for d in req_dlc_dirs:
assert d in files_and_dirs, \
f"zipped DLC directory must include folder named {d}"
src = os.path.join(unzipped_dir, d)
dst = os.path.join(self.proj_dir_abs, d)
if d == "labeled-data":
shutil.copytree(src, dst)
else:
shutil.copyfile(src, dst)
copy_and_reformat_video_directory(src_dir=src, dst_dir=dst)

elif self.st_existing_project_format == "DLC":

# copy files
files_and_dirs = os.listdir(unzipped_dir)
req_dlc_dirs = ["labeled-data", "videos"]
for d in req_dlc_dirs:
assert d in files_and_dirs, f"zipped DLC directory must include folder named {d}"
src = os.path.join(unzipped_dir, d)
dst = os.path.join(self.proj_dir_abs, d)
if d == "labeled-data":
shutil.copytree(src, dst)
else:
copy_and_reformat_video_directory(src_dir=src, dst_dir=dst)
# create single csv file of labels out of video-specific label files
df_all = collect_dlc_labels(self.proj_dir_abs)
df_all.to_csv(os.path.join(self.proj_dir_abs, COLLECTED_DATA_FILENAME))

# create single csv file of labels out of video-specific label files
df_all = collect_dlc_labels(self.proj_dir_abs)
df_all.to_csv(os.path.join(self.proj_dir_abs, COLLECTED_DATA_FILENAME))
# flag finish coping all files
finished_copy_files = True
else:
raise NotImplementedError("Can only import 'Lightning Pose' or 'DLC' projects")

else:
raise NotImplementedError("Can only import 'Lightning Pose' or 'DLC' projects")
# remove zipped file from project folder
if finished_copy_files:
if os.path.exists(self.st_upload_existing_project_zippath):
os.remove(self.st_upload_existing_project_zippath)
if os.path.isdir(unzipped_dir):
shutil.rmtree(unzipped_dir)

except Exception as e:
print(f"An error occurred: {e}")

# create 'selected_frames.csv' file for each video subdirectory
# this is required to import frames into label studio, so that we don't confuse context
Expand All @@ -362,7 +384,6 @@ def contains_videos(file_or_dir):
self.count_upload_existing += 1

def _delete_project(self, **kwargs):

# delete project locally
if os.path.exists(self.proj_dir_abs):
shutil.rmtree(self.proj_dir_abs)
Expand Down Expand Up @@ -427,28 +448,28 @@ def check_files_in_zipfile(filepath: str, project_type: str = "Lightning Pose")
if project_type not in ["DLC", "Lightning Pose"]:
raise NotImplementedError

expected_dirs = [VIDEOS_DIR, LABELED_DATA_DIR, COLLECTED_DATA_FILENAME]

error_flag = False
error_msg = ""
error_msgs = [] # Collect error messages in a list

with zipfile.ZipFile(filepath) as z:
zipname = os.path.basename(filepath).replace(".zip", "")
files = z.namelist()
if project_type == "Lightning Pose" or project_type == "DLC":
if os.path.join(zipname, LABELED_DATA_DIR, "") not in files:
error_flag = True
error_msg += f"""
ERROR: Your directory of labeled frames must be named "{LABELED_DATA_DIR}"
If you change this directory name, make sure to update the filepaths in the
labeled data csv file as well!
<br /><br />
"""
if os.path.join(zipname, VIDEOS_DIR, "") not in files:

# Iterate over each expected directory and check if it's present
for expected_dir in expected_dirs:
# Adjusting the logic to check the presence of directories correctly
if not any(f"{expected_dir}" in file for file in files):
error_flag = True
error_msg += f"""
ERROR: Your directory of videos must be named "{VIDEOS_DIR}" (can be empty)
<br /><br />
"""
else:
raise NotImplementedError
# Append specific error message for the missing directory
error_msgs.append(
f"ERROR: Your directory of {expected_dir} must be named "
f"\"{expected_dir}\" (can be empty)."
)

# Joining all error messages with breaks for HTML formatting,
# if you're displaying this in a web context
error_msg = "<br /><br />".join(error_msgs)

proceed_fmt = "<p style='font-family:sans-serif; color:Red;'>%s</p>"

Expand Down Expand Up @@ -558,6 +579,8 @@ def _render_streamlit_fn(state: AppState):
# ----------------------------------------------------
# we'll only allow config updates once the user has defined an allowable project name
if st_project_name:
# Check no other keys but letters, numbers
st_project_name = st_project_name.replace(' ', '_')
if st_mode == LOAD_STR:
if st_project_name not in state.initialized_projects:
# catch user error
Expand Down Expand Up @@ -658,7 +681,8 @@ def _render_streamlit_fn(state: AppState):
bytes_data = uploaded_file.read()
# name it
filename = uploaded_file.name
filepath = os.path.join(os.getcwd(), "data", filename)
filename_temp = filename.replace(".zip", '_temp.zip')
filepath = os.path.join(os.getcwd(), "data", filename_temp)
# write the content of the file to the path if it doesn't already exist
if not os.path.exists(filepath):
with open(filepath, "wb") as f:
Expand All @@ -669,9 +693,11 @@ def _render_streamlit_fn(state: AppState):
# grab keypoint names
st_keypoints = get_keypoints_from_zipfile(filepath, project_type=st_prev_format)
# update relevant vars

state.st_upload_existing_project_zippath = filepath
enter_data = True
st_mode = CREATE_STR

st.caption(
"If your zip file is larger than the 200MB limit, see the [FAQ]"
"(https://pose-app.readthedocs.io/en/latest/source/faqs.html#faq-upload-limit)",
Expand Down
10 changes: 10 additions & 0 deletions lightning_pose_app/ui/streamlit_video_viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,16 @@ def _render_streamlit_fn(state: AppState):
# read and show the predictions labeled video
video_file = open(selected_video, "rb")
video_bytes = video_file.read()
custom_css = """
<style>
video {
width: 100% !important;
height: 50% !important;
}
</style>
"""
st.markdown(custom_css, unsafe_allow_html=True)

st.video(video_bytes)
else:
st.write("No video to preview")