Skip to content

Commit

Permalink
fix rgb to bgr and remove check
Browse files Browse the repository at this point in the history
  • Loading branch information
Louis-Dupont committed Apr 2, 2023
1 parent 265a828 commit d2be717
Showing 1 changed file with 11 additions and 17 deletions.
28 changes: 11 additions & 17 deletions src/super_gradients/training/utils/videos.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,26 @@
from typing import List, Optional, Tuple
import cv2
import os

import numpy as np

SUPPORTED_FORMATS = (".avi", ".mp4", ".mov", ".wmv", ".mkv")


__all__ = ["load_video", "save_video"]


def load_video(file_path: str, max_frames: Optional[int] = None) -> List[np.ndarray]:
def load_video(file_path: str, max_frames: Optional[int] = None) -> Tuple[List[np.ndarray], int]:
"""Open a video file and extract each frame into numpy array.
:param file_path: Path to the video file.
:param max_frames: Optional, maximum number of frames to extract.
:return: Frames representing the video, each in (H, W, C).
:return:
- Frames representing the video, each in (H, W, C), RGB.
- Frames per Second (FPS).
"""
cap = _open_video(file_path)
frames = _extract_frames(cap, max_frames)
fps = cap.get(cv2.CAP_PROP_FPS)
cap.release()
return frames
return frames, fps


def _open_video(file_path: str) -> cv2.VideoCapture:
Expand All @@ -29,10 +29,6 @@ def _open_video(file_path: str) -> cv2.VideoCapture:
:param file_path: Path to the video file
:return: Opened video capture object
"""
ext = os.path.splitext(file_path)[-1].lower()
if ext not in SUPPORTED_FORMATS:
raise RuntimeError(f"Not supported video format {ext}. Supported formats: {SUPPORTED_FORMATS}")

cap = cv2.VideoCapture(file_path)
if not cap.isOpened():
raise ValueError(f"Failed to open video file: {file_path}")
Expand All @@ -44,15 +40,15 @@ def _extract_frames(cap: cv2.VideoCapture, max_frames: Optional[int] = None) ->
:param cap: Opened video capture object.
:param max_frames: Optional maximum number of frames to extract.
:return: Frames representing the video, each in (H, W, C).
:return: Frames representing the video, each in (H, W, C), RGB.
"""
frames = []

while max_frames != len(frames):
frame_read_success, frame = cap.read()
if not frame_read_success:
break
frames.append(frame)
frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))

return frames

Expand All @@ -61,7 +57,7 @@ def save_video(output_path: str, frames: List[np.ndarray], fps: int) -> None:
"""Save a video locally.
:param output_path: Where the video will be saved
:param frames: Frames representing the video, each in (H, W, C). Note that all the frames are expected to have the same shape.
:param frames: Frames representing the video, each in (H, W, C), RGB. Note that all the frames are expected to have the same shape.
:param fps: Frames per second
"""
video_height, video_width = _validate_frames(frames)
Expand All @@ -74,17 +70,15 @@ def save_video(output_path: str, frames: List[np.ndarray], fps: int) -> None:
)

for frame in frames:
if frame.ndim == 2:
frame = frame[:, :, np.newaxis]
video_writer.write(frame)
video_writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))

video_writer.release()


def _validate_frames(frames: List[np.ndarray]) -> Tuple[float, float]:
"""Validate the frames to make sure that every frame has the same size and includes the channel dimension. (i.e. (H, W, C))
:param frames: Frames representing the video, each in (H, W, C). Note that all the frames are expected to have the same shape.
:param frames: Frames representing the video, each in (H, W, C), RGB. Note that all the frames are expected to have the same shape.
:return: (Height, Weight) of the video.
"""
min_height = min(frame.shape[0] for frame in frames)
Expand Down

0 comments on commit d2be717

Please sign in to comment.