
How to do object detection on video or streaming data #2045

Closed
Surayuth opened this issue Jan 26, 2021 · 24 comments

Labels
question Further information is requested

Comments

@Surayuth

❔Question

I'm new to this framework. Can someone guide me on how to do object detection on video and streaming data using YOLOv5? My problem is that I want to detect objects from a video game in real time (while I'm playing the game). Do I have to capture the screen every second and then pass the screenshot to the model?

Thanks.


Surayuth added the question label on Jan 26, 2021
@glenn-jocher
Member

@Surayuth see README for inference examples on videos and streaming sources:
https://github.com/ultralytics/yolov5#inference

Inference

detect.py runs inference on a variety of sources, downloading models automatically from the latest YOLOv5 release and saving results to runs/detect.

$ python detect.py --source 0  # webcam
                            file.jpg  # image 
                            file.mp4  # video
                            path/  # directory
                            path/*.jpg  # glob
                            rtsp://170.93.143.139/rtplive/470011e600ef003a004ee33696235daa  # rtsp stream
                            rtmp://192.168.1.105/live/test  # rtmp stream
                            http://112.50.243.8/PLTV/88888888/224/3221225900/1.m3u8  # http stream

@meadlai

meadlai commented Nov 6, 2021

@glenn-jocher, thank you for your guide. With the webcam as the source, is there any way to show/display the detection results directly? And to get the coordinates of the objects?

@glenn-jocher
Member

glenn-jocher commented Nov 8, 2021

@meadlai python detect.py --source 0 already displays the results in real time.

For returning results in a Python environment, you probably want to use a YOLOv5 PyTorch Hub model:

import torch

# Model
model = torch.hub.load('ultralytics/yolov5', 'yolov5s')

# Image
img = 'https://ultralytics.com/images/zidane.jpg'

# Inference
results = model(img)

results.pandas().xyxy[0]
#      xmin    ymin    xmax   ymax  confidence  class    name
# 0  749.50   43.50  1148.0  704.5    0.874023      0  person
# 1  433.50  433.50   517.5  714.5    0.687988     27     tie
# 2  114.75  195.75  1095.0  708.0    0.624512      0  person
# 3  986.00  304.00  1028.0  420.0    0.286865     27     tie
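
To pull the coordinates out programmatically (per the question above), a minimal sketch using the same model and results: results.xyxy[0] is an (n, 6) tensor holding x1, y1, x2, y2, confidence, and class for each detection.

# Minimal sketch: iterate over the raw detections from the hub model above
for *box, conf, cls in results.xyxy[0].tolist():
    x1, y1, x2, y2 = box
    print(f'{model.names[int(cls)]}: conf={conf:.2f} box=({x1:.0f}, {y1:.0f}, {x2:.0f}, {y2:.0f})')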

See PyTorch Hub tutorial for details:

YOLOv5 Tutorials

@Neel7317

@meadlai here is the code you are looking for:

import torch
import cv2
from time import time

class OD:

    def __init__(self, capture_index, model_name):
        """
        Initializes the detector with a capture source and model weights.
        :param capture_index: index of the video capture device (e.g. 0 for the default webcam).
        :param model_name: path to custom .pt weights, or an empty value to use pretrained yolov5s.
        """
        self.capture_index = capture_index
        self.model = self.load_model(model_name)
        self.classes = self.model.names
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print("Using Device: ", self.device)

    def get_video_capture(self):
        """
        Creates a new video capture object to extract the video frame by frame for prediction.
        :return: OpenCV VideoCapture object.
        """
        return cv2.VideoCapture(self.capture_index)

    def load_model(self, model_name):
        """
        Loads a YOLOv5 model from PyTorch Hub.
        :return: trained PyTorch model.
        """
        if model_name:
            model = torch.hub.load('ultralytics/yolov5', 'custom', path=model_name, force_reload=True)
        else:
            model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True)
        return model

    def score_frame(self, frame):
        """
        Takes a single frame as input and scores it with the YOLOv5 model.
        :param frame: input frame as a numpy array.
        :return: labels and normalized xyxyn coordinates of the objects detected in the frame.
        """
        self.model.to(self.device)
        frame = [frame]
        results = self.model(frame)
        labels, cord = results.xyxyn[0][:, -1], results.xyxyn[0][:, :-1]
        return labels, cord

    def class_to_label(self, x):
        """
        For a given numeric label, return the corresponding string label.
        :param x: numeric label
        :return: corresponding string label
        """
        return self.classes[int(x)]

    def plot_boxes(self, results, frame):
        """
        Takes a frame and its results as input, and plots the bounding boxes and labels onto the frame.
        :param results: labels and coordinates predicted by the model for the given frame.
        :param frame: frame which has been scored.
        :return: frame with bounding boxes and labels plotted on it.
        """
        labels, cord = results
        n = len(labels)
        x_shape, y_shape = frame.shape[1], frame.shape[0]
        for i in range(n):
            row = cord[i]
            if row[4] >= 0.3:  # confidence threshold
                # Coordinates are normalized, so scale them back to pixel values
                x1, y1, x2, y2 = int(row[0]*x_shape), int(row[1]*y_shape), int(row[2]*x_shape), int(row[3]*y_shape)
                bgr = (0, 255, 0)
                cv2.rectangle(frame, (x1, y1), (x2, y2), bgr, 2)
                cv2.putText(frame, self.class_to_label(labels[i]), (x1, y1), cv2.FONT_HERSHEY_SIMPLEX, 0.9, bgr, 2)

        return frame

    def __call__(self):
        """
        Runs the main loop: reads the video frame by frame, scores and annotates each frame,
        and displays the result until 'q' is pressed.
        """
        cap = self.get_video_capture()
        assert cap.isOpened()

        while True:
            ret, frame = cap.read()
            assert ret

            frame = cv2.resize(frame, (640, 640))

            start_time = time()
            results = self.score_frame(frame)
            frame = self.plot_boxes(results, frame)
            end_time = time()

            # Guard against division by zero on very fast frames
            fps = 1 / max(end_time - start_time, 1e-6)
            cv2.putText(frame, f'FPS: {int(fps)}', (20, 70), cv2.FONT_HERSHEY_SIMPLEX, 1.5, (0, 255, 0), 2)

            cv2.imshow('YOLOv5 Detection', frame)

            if cv2.waitKey(1) & 0xFF == ord('q'):
                break

        cap.release()
        cv2.destroyAllWindows()

Create a new object and execute.

detector = OD(capture_index=0, model_name='320_yolo_400.pt')
detector()
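
Note: '320_yolo_400.pt' above is the commenter's own custom weights file; per load_model, passing an empty model_name falls back to the pretrained yolov5s weights.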

@gustavozantut

https://github.com/gustavozantut/yolov5_live_results streams output frame results in real time.

@wb-08

wb-08 commented Sep 10, 2022

import cv2
import numpy as np
import torch
from PIL import Image
from mss import mss

# Load custom weights from a local yolov5 repo clone
model = torch.hub.load("yolov5", 'custom', path="/yolov5/best.pt", source='local')

sct = mss()
w, h = 1920, 1080
monitor = {'top': 0, 'left': 0, 'width': w, 'height': h}

while True:
    # Grab the screen as an RGB image (the hub model expects RGB numpy arrays)
    img = Image.frombytes('RGB', (w, h), sct.grab(monitor).rgb)
    screen = np.array(img)
    # Run the model on the screen capture
    result = model(screen, size=640)
    print(result)
    # render() draws the boxes in place and returns RGB; convert to BGR for imshow
    cv2.imshow('Screen', cv2.cvtColor(result.render()[0], cv2.COLOR_RGB2BGR))

    if cv2.waitKey(25) & 0xFF == ord('q'):
        cv2.destroyAllWindows()
        break

@glenn-jocher
Member

glenn-jocher commented Sep 10, 2022

@wb-08 how does mss() compare to our official screenshot example?

[screenshot of the official screenshot inference example]

@wb-08

wb-08 commented Sep 11, 2022

@glenn-jocher, but it doesn't work correctly for video

@glenn-jocher
Member

@wb-08 yes, the example is for an image, but it can obviously be dropped into a loop and turned into video.

@wb-08

wb-08 commented Sep 11, 2022

@glenn-jocher, maybe into a while loop?
In any case, you can try my solution and the official one and see that mine works better :)

@glenn-jocher
Member

glenn-jocher commented Sep 11, 2022

@wb-08 got it. I tried it out, and mss seems to be much faster than the PIL image grab method: about 30 ms vs 400 ms for the tutorial method:

import time

import cv2
import mss
import numpy

with mss.mss() as sct:
    # Part of the screen to capture
    monitor = sct.monitors[0]

    while "Screen capturing":
        last_time = time.time()

        # Get raw pixels from the screen, save it to a Numpy array
        img = numpy.array(sct.grab(monitor))

        # Display the picture
        cv2.imshow("OpenCV/Numpy normal", img)

        # Display the picture in grayscale
        # cv2.imshow('OpenCV/Numpy grayscale',
        #            cv2.cvtColor(img, cv2.COLOR_BGRA2GRAY))

        print("fps: {}".format(1 / (time.time() - last_time)))

        # Press "q" to quit
        if cv2.waitKey(25) & 0xFF == ord("q"):
            cv2.destroyAllWindows()
            break

@glenn-jocher
Member

@AyushExel this mss screenshot loader works really well; we should integrate it into StreamLoader with a reserved detect.py source name, like this:

python detect.py --source screen

@zombob
Contributor

zombob commented Sep 21, 2022


Maybe better:
python detect.py --source screen top left width height

@glenn-jocher
Member

@zombob not a bad idea. We're super busy and haven't had time to work on this feature yet, but if you'd like to help with a PR that would be great!

@zombob
Contributor

zombob commented Sep 22, 2022


Fixed this, PR link: #9542
Use it like:

python .\detect.py --source screen   # default: full screen (screen 0)
python .\detect.py --source "screen 2"   # 2nd screen; if you have multiple monitors, you can specify the screen number
python .\detect.py --source "screen 500 600 256 256"   # specify top, left, width and height
python .\detect.py --source "screen 1 500 100 256 256"   # specify screen number, top, left, width and height

@DABHIHARDIK

If I want to do detection on a YouTube live-stream video or an RTMP server, what should I do?

@gustavozantut

gustavozantut commented May 2, 2023


Just specify "--source {link_here}" when calling detect.py. YouTube works, RTSP too; not sure about RTMP.

@glenn-jocher
Member

@gustavozantut yes, you can use the YOLOv5 detector to perform object detection on YouTube live-streaming videos or RTMP servers. You just need to specify the source URL when calling detect.py with the --source flag, as in the following example:

python detect.py --source "http://youtube.com/watch?v=abcdefghijk"

Note that you should specify the streaming link and not the webpage's link. For RTMP servers I'm not completely sure, but you may be able to use the same command; just replace the source URL with your RTMP server URL.
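
For reference, the README snippet earlier in this thread already lists an RTMP stream among detect.py sources, so a command of this shape should work (URL purely illustrative, taken from that snippet):

python detect.py --source "rtmp://192.168.1.105/live/test"  # rtmp stream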

@fatmaboodai


Hello,
Is it possible to have the detection happen in real time as an overlay within the Chrome browser window? For example, could I pass the screen recording to the model and render the boxes as an overlay on the tab I'm viewing inside the Chrome browser?

@glenn-jocher
Member

@fatmaboodai yes, it is possible to achieve real-time object detection as an overlay within a Chrome browser window by using the YOLOv5 model. You can capture the screen recording and pass it to the model for object detection. Then, you can use browser-based technologies such as WebRTC, HTML5 canvas, or WebGL to overlay the detected boxes onto the tab you are viewing inside the Chrome browser. This would involve a combination of capturing the screen, processing the frames with YOLOv5, and rendering the results as an overlay in the browser window.
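
As a minimal sketch of one piece of that pipeline (assuming Flask, mss, and the pretrained hub model; this shows detections in a Chrome tab at http://localhost:5000/ rather than as a true overlay on top of another tab, which would still need the browser-side work described above), the annotated screen captures can be served as an MJPEG stream:

import cv2
import numpy as np
import torch
from flask import Flask, Response
from mss import mss

app = Flask(__name__)
model = torch.hub.load('ultralytics/yolov5', 'yolov5s')  # or your custom weights

def frames():
    with mss() as sct:
        monitor = sct.monitors[1]  # primary monitor
        while True:
            # Grab the screen (BGRA) and convert to RGB for the hub model
            rgb = cv2.cvtColor(np.array(sct.grab(monitor)), cv2.COLOR_BGRA2RGB)
            results = model(rgb)
            # render() draws boxes in place; convert back to BGR for JPEG encoding
            annotated = cv2.cvtColor(results.render()[0], cv2.COLOR_RGB2BGR)
            ok, jpg = cv2.imencode('.jpg', annotated)
            if ok:
                yield b'--frame\r\nContent-Type: image/jpeg\r\n\r\n' + jpg.tobytes() + b'\r\n'

@app.route('/')
def stream():
    return Response(frames(), mimetype='multipart/x-mixed-replace; boundary=frame')

if __name__ == '__main__':
    app.run()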

@fatmaboodai


Just another question:
I already have a model trained on a custom dataset that is YOLOv8-based; can I achieve the same goal using it?

@glenn-jocher
Member

@fatmaboodai Yes, you can achieve the same goal using your custom YOLOv8-based model. You can capture the screen recording, pass it to your custom model for object detection, and then use browser-based technologies to overlay the detected boxes onto the Chrome browser window. The process would be similar to what I described earlier, but using your custom YOLOv8-based model for object detection instead of YOLOv5.
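
A minimal sketch under those assumptions, using the ultralytics package's YOLOv8 API with mss for screen capture ('best.pt' is a hypothetical path standing in for your custom weights):

import cv2
import numpy as np
from mss import mss
from ultralytics import YOLO  # YOLOv8 API

model = YOLO('best.pt')  # hypothetical path to your custom YOLOv8 weights

with mss() as sct:
    monitor = sct.monitors[1]  # primary monitor
    while True:
        # Grab the screen and drop the alpha channel (mss returns BGRA)
        frame = np.array(sct.grab(monitor))[:, :, :3]
        results = model(frame)
        # plot() returns the frame with boxes and labels drawn on it
        cv2.imshow('YOLOv8 Screen Detection', results[0].plot())
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break

cv2.destroyAllWindows()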

@fatmaboodai

Thank you so much, I really appreciate it 🙏🏽🙏🏽

@glenn-jocher
Member

@fatmaboodai you're welcome! If you have any more questions or need further assistance, feel free to ask. Good luck with your project!
