diff --git a/docs/en/datasets/detect/roboflow-100.md b/docs/en/datasets/detect/roboflow-100.md index fe54138d85f..3ca3e4f1b1e 100644 --- a/docs/en/datasets/detect/roboflow-100.md +++ b/docs/en/datasets/detect/roboflow-100.md @@ -48,6 +48,7 @@ Dataset benchmarking evaluates machine learning model performance on specific da ```python from pathlib import Path import shutil + import os from ultralytics.utils.benchmarks import RF100Benchmark # Initialize RF100Benchmark and set API key @@ -65,10 +66,10 @@ Dataset benchmarking evaluates machine learning model performance on specific da if path.exists(): # Fix YAML file and run training benchmark.fix_yaml(str(path)) - Path.cwd().system(f'yolo detect train data={path} model=yolov8s.pt epochs=1 batch=16') + os.system(f'yolo detect train data={path} model=yolov8s.pt epochs=1 batch=16') # Run validation and evaluate - Path.cwd().system(f'yolo detect val data={path} model=runs/detect/train/weights/best.pt > {val_log_file} 2>&1') + os.system(f'yolo detect val data={path} model=runs/detect/train/weights/best.pt > {val_log_file} 2>&1') benchmark.evaluate(str(path), str(val_log_file), str(eval_log_file), ind) # Remove the 'runs' directory diff --git a/docs/en/guides/index.md b/docs/en/guides/index.md index 0dd5f9db174..dca4a0bb962 100644 --- a/docs/en/guides/index.md +++ b/docs/en/guides/index.md @@ -57,6 +57,7 @@ Here's a compilation of in-depth guides to help you master different aspects of - [Speed Estimation](speed-estimation.md) 🚀 NEW: Speed estimation in computer vision relies on analyzing object motion through techniques like [object tracking](https://docs.ultralytics.com/modes/track/), crucial for applications like autonomous vehicles and traffic monitoring. - [Distance Calculation](distance-calculation.md) 🚀 NEW: Distance calculation, which involves measuring the separation between two objects within a defined space, is a crucial aspect. In the context of Ultralytics YOLOv8, the method employed for this involves using the bounding box centroid to determine the distance associated with user-highlighted bounding boxes. - [Queue Management](queue-management.md) 🚀 NEW: Queue management is the practice of efficiently controlling and directing the flow of people or tasks, often through strategic planning and technology implementation, to minimize wait times and improve overall productivity. +- [Parking Management](parking-management.md) 🚀 NEW: Parking management involves efficiently organizing and directing the flow of vehicles in parking areas, often through strategic planning and technology integration, to optimize space utilization and enhance user experience. ## Contribute to Our Guides diff --git a/docs/en/guides/parking-management.md b/docs/en/guides/parking-management.md new file mode 100644 index 00000000000..43ae97fda05 --- /dev/null +++ b/docs/en/guides/parking-management.md @@ -0,0 +1,116 @@ +--- +comments: true +description: Parking Management System Using Ultralytics YOLOv8 +keywords: Ultralytics, YOLOv8, Object Detection, Object Counting, Parking lots, Object Tracking, Notebook, IPython Kernel, CLI, Python SDK +--- + +# Parking Management using Ultralytics YOLOv8 🚀 + +## What is Parking Management System? + +Parking management with [Ultralytics YOLOv8](https://github.com/ultralytics/ultralytics/) ensures efficient and safe parking by organizing spaces and monitoring availability. YOLOv8 can improve parking lot management through real-time vehicle detection, and insights into parking occupancy. + +## Advantages of Parking Management System? + +- **Efficiency**: Parking lot management optimizes the use of parking spaces and reduces congestion. +- **Safety and Security**: Parking management using YOLOv8 improves the safety of both people and vehicles through surveillance and security measures. +- **Reduced Emissions**: Parking management using YOLOv8 manages traffic flow to minimize idle time and emissions in parking lots. + +## Real World Applications + +| Parking Management System | Parking Management System | +|:-------------------------------------------------------------------------------------------------------------------------------------------------------:|:------------------------------------------------------------------------------------------------------------------------------------------------------------:| +| ![Parking lots Analytics Using Ultralytics YOLOv8](https://github.com/RizwanMunawar/RizwanMunawar/assets/62513924/e3d4bc3e-cf4a-4da9-b42e-0da55cc74ad6) | ![Parking management top view using Ultralytics YOLOv8](https://github.com/RizwanMunawar/RizwanMunawar/assets/62513924/fe186719-1aca-43c9-b388-1ded91280eb5) | +| Parking management Aeriel View using Ultralytics YOLOv8 | Parking management Top View using Ultralytics YOLOv8 | + + +## Parking Management System Code Workflow + +### Selection of Points + +!!! Tip "Point Selection is now Easy" + + Choosing parking points is a critical and complex task in parking management systems. Ultralytics streamlines this process by providing a tool that lets you define parking lot areas, which can be utilized later for additional processing. + +- Capture a frame from the video or camera stream where you want to manage the parking lot. +- Use the provided code to launch a graphical interface, where you can select an image and start outlining parking regions by mouse click to create polygons. + +!!! Warning "Image Size" + + Max Image Size of 1920 * 1080 supported + +```python +from ultralytics.solutions.parking_management import ParkingPtsSelection, tk +root = tk.Tk() +ParkingPtsSelection(root) +root.mainloop() +``` + +- After defining the parking areas with polygons, click `save` to store a JSON file with the data in your working directory. + +![Ultralytics YOLOv8 Points Selection Demo](https://github.com/RizwanMunawar/RizwanMunawar/assets/62513924/72737b8a-0f0f-4efb-98ad-b917a0039535) + + +### Python Code for Parking Management + +!!! Example "Parking management using YOLOv8 Example" + + === "Parking Management" + + ```python + import cv2 + from ultralytics.solutions.parking_management import ParkingManagement + + # Path to json file, that created with above point selection app + polygon_json_path = "bounding_boxes.json" + + # Video Capture + cap = cv2.VideoCapture("Path/to/video/file.mp4") + assert cap.isOpened(), "Error reading video file" + w, h, fps = (int(cap.get(x)) for x in (cv2.CAP_PROP_FRAME_WIDTH, cv2.CAP_PROP_FRAME_HEIGHT, cv2.CAP_PROP_FPS)) + video_writer = cv2.VideoWriter("parking management.avi", cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h)) + + # Initialize parking management object + management = ParkingManagement(model_path="yolov8n.pt") + + while cap.isOpened(): + ret, im0 = cap.read() + if not ret: + break + json_data = management.parking_regions_extraction(polygon_json_path) + results = management.model.track(im0, persist=True, show=False) + + if results[0].boxes.id is not None: + boxes = results[0].boxes.xyxy.cpu().tolist() + clss = results[0].boxes.cls.cpu().tolist() + management.process_data(json_data, im0, boxes, clss) + + management.display_frames(im0) + video_writer.write(im0) + + cap.release() + video_writer.release() + cv2.destroyAllWindows() + ``` + +### Optional Arguments `ParkingManagement()` + +| Name | Type | Default | Description | +|--------------------------|-------------|-------------------|-----------------------------------------------------| +| `occupied_region_color` | `RGB Color` | `(0, 255, 0)` | Parking space occupied region color | +| `available_region_color` | `RGB Color` | `(0, 0, 255)` | Parking space available region color | +| `margin` | `int` | `10` | Gap between text display for multiple classes count | +| `txt_color` | `RGB Color` | `(255, 255, 255)` | Foreground color for object counts text | +| `bg_color` | `RGB Color` | `(255, 255, 255)` | Rectangle behind text background color | + +### Arguments `model.track` + +| Name | Type | Default | Description | +|-----------|---------|----------------|-------------------------------------------------------------| +| `source` | `im0` | `None` | source directory for images or videos | +| `persist` | `bool` | `False` | persisting tracks between frames | +| `tracker` | `str` | `botsort.yaml` | Tracking method 'bytetrack' or 'botsort' | +| `conf` | `float` | `0.3` | Confidence Threshold | +| `iou` | `float` | `0.5` | IOU Threshold | +| `classes` | `list` | `None` | filter results by class, i.e. classes=0, or classes=[0,2,3] | +| `verbose` | `bool` | `True` | Display the object tracking results | diff --git a/docs/en/reference/solutions/parking_management.md b/docs/en/reference/solutions/parking_management.md new file mode 100644 index 00000000000..8fec60230e4 --- /dev/null +++ b/docs/en/reference/solutions/parking_management.md @@ -0,0 +1,20 @@ +--- +description: Parking management system using Ultralytics YOLO featuring cutting-edge technology for precise real-time occupancy and availability monitoring for parking lots. +keywords: Ultralytics YOLO, object tracking software, real-time counting solutions, video stream analysis, YOLOv8 object detection, AI surveillance, smart counting technology, computer vision, AI-powered tracking, object counting accuracy, video analytics tools, automated monitoring. +--- + +# Reference for `ultralytics/solutions/parking_management.py` + +!!! Note + + This file is available at [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/solutions/parking_management.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/solutions/parking_management.py). If you spot a problem please help fix it by [contributing](https://docs.ultralytics.com/help/contributing/) a [Pull Request](https://github.com/ultralytics/ultralytics/edit/main/ultralytics/solutions/parking_management.py) 🛠️. Thank you 🙏! + +

+ +## ::: ultralytics.solutions.parking_management.ParkingManagement + +

+ +## ::: ultralytics.solutions.parking_management.ParkingPtsSelection + +

diff --git a/mkdocs.yml b/mkdocs.yml index c5988ee6954..2a07c5fc41d 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -158,7 +158,8 @@ nav: - datasets/index.md - Guides: - guides/index.md - - NEW 🚀 Explorer: + - New 🚀 Parking Management: guides/parking-management.md + - Explorer: - datasets/explorer/index.md - Languages: - 🇬🇧  English: https://ultralytics.com/docs/ @@ -301,6 +302,7 @@ nav: - Speed Estimation: guides/speed-estimation.md - Distance Calculation: guides/distance-calculation.md - Queue Management: guides/queue-management.md + - Parking Management: guides/parking-management.md - YOLOv5: - yolov5/index.md - Quickstart: yolov5/quickstart_tutorial.md @@ -499,6 +501,7 @@ nav: - object_counter: reference/solutions/object_counter.md - queue_management: reference/solutions/queue_management.md - speed_estimation: reference/solutions/speed_estimation.md + - parking_management: reference/solutions/parking_management.md - trackers: - basetrack: reference/trackers/basetrack.md - bot_sort: reference/trackers/bot_sort.md diff --git a/ultralytics/__init__.py b/ultralytics/__init__.py index 8ec03a62893..45d7f356639 100644 --- a/ultralytics/__init__.py +++ b/ultralytics/__init__.py @@ -1,6 +1,6 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license -__version__ = "8.2.4" +__version__ = "8.2.5" from ultralytics.data.explorer.explorer import Explorer from ultralytics.models import RTDETR, SAM, YOLO, YOLOWorld diff --git a/ultralytics/solutions/heatmap.py b/ultralytics/solutions/heatmap.py index 1c71eb73301..4ff305d7937 100644 --- a/ultralytics/solutions/heatmap.py +++ b/ultralytics/solutions/heatmap.py @@ -190,9 +190,7 @@ def generate_heatmap(self, im0, tracks): for box, cls, track_id in zip(self.boxes, self.clss, self.track_ids): # Store class info if self.names[cls] not in self.class_wise_count: - if len(self.names[cls]) > 5: - self.names[cls] = self.names[cls][:5] - self.class_wise_count[self.names[cls]] = {"in": 0, "out": 0} + self.class_wise_count[self.names[cls]] = {"IN": 0, "OUT": 0} if self.shape == "circle": center = (int((box[0] + box[2]) // 2), int((box[1] + box[3]) // 2)) @@ -225,10 +223,10 @@ def generate_heatmap(self, im0, tracks): if (box[0] - prev_position[0]) * (self.counting_region.centroid.x - prev_position[0]) > 0: self.in_counts += 1 - self.class_wise_count[self.names[cls]]["in"] += 1 + self.class_wise_count[self.names[cls]]["IN"] += 1 else: self.out_counts += 1 - self.class_wise_count[self.names[cls]]["out"] += 1 + self.class_wise_count[self.names[cls]]["OUT"] += 1 # Count objects using line elif len(self.count_reg_pts) == 2: @@ -239,10 +237,10 @@ def generate_heatmap(self, im0, tracks): if (box[0] - prev_position[0]) * (self.counting_region.centroid.x - prev_position[0]) > 0: self.in_counts += 1 - self.class_wise_count[self.names[cls]]["in"] += 1 + self.class_wise_count[self.names[cls]]["IN"] += 1 else: self.out_counts += 1 - self.class_wise_count[self.names[cls]]["out"] += 1 + self.class_wise_count[self.names[cls]]["OUT"] += 1 else: for box, cls in zip(self.boxes, self.clss): @@ -264,28 +262,21 @@ def generate_heatmap(self, im0, tracks): heatmap_normalized = cv2.normalize(self.heatmap, None, 0, 255, cv2.NORM_MINMAX) heatmap_colored = cv2.applyColorMap(heatmap_normalized.astype(np.uint8), self.colormap) - label = "Ultralytics Analytics \t" + labels_dict = {} for key, value in self.class_wise_count.items(): - if value["in"] != 0 or value["out"] != 0: + if value["IN"] != 0 or value["OUT"] != 0: if not self.view_in_counts and not self.view_out_counts: - label = None + continue elif not self.view_in_counts: - label += f"{str.capitalize(key)}: IN {value['in']} \t" + labels_dict[str.capitalize(key)] = f"OUT {value['OUT']}" elif not self.view_out_counts: - label += f"{str.capitalize(key)}: OUT {value['out']} \t" + labels_dict[str.capitalize(key)] = f"IN {value['IN']}" else: - label += f"{str.capitalize(key)}: IN {value['in']} OUT {value['out']} \t" + labels_dict[str.capitalize(key)] = f"IN {value['IN']} OUT {value['OUT']}" - label = label.rstrip() - label = label.split("\t") - - if self.count_reg_pts is not None and label is not None: - self.annotator.display_counts( - counts=label, - count_txt_color=self.count_txt_color, - count_bg_color=self.count_bg_color, - ) + if labels_dict is not None: + self.annotator.display_analytics(self.im0, labels_dict, self.count_txt_color, self.count_bg_color, 10) self.im0 = cv2.addWeighted(self.im0, 1 - self.heatmap_alpha, heatmap_colored, self.heatmap_alpha, 0) diff --git a/ultralytics/solutions/object_counter.py b/ultralytics/solutions/object_counter.py index e7174b0d6ff..9dc4ba63fa6 100644 --- a/ultralytics/solutions/object_counter.py +++ b/ultralytics/solutions/object_counter.py @@ -181,9 +181,7 @@ def extract_and_process_tracks(self, tracks): # Store class info if self.names[cls] not in self.class_wise_count: - if len(self.names[cls]) > 5: - self.names[cls] = self.names[cls][:5] - self.class_wise_count[self.names[cls]] = {"in": 0, "out": 0} + self.class_wise_count[self.names[cls]] = {"IN": 0, "OUT": 0} # Draw Tracks track_line = self.track_history[track_id] @@ -210,10 +208,10 @@ def extract_and_process_tracks(self, tracks): if (box[0] - prev_position[0]) * (self.counting_region.centroid.x - prev_position[0]) > 0: self.in_counts += 1 - self.class_wise_count[self.names[cls]]["in"] += 1 + self.class_wise_count[self.names[cls]]["IN"] += 1 else: self.out_counts += 1 - self.class_wise_count[self.names[cls]]["out"] += 1 + self.class_wise_count[self.names[cls]]["OUT"] += 1 # Count objects using line elif len(self.reg_pts) == 2: @@ -224,33 +222,26 @@ def extract_and_process_tracks(self, tracks): if (box[0] - prev_position[0]) * (self.counting_region.centroid.x - prev_position[0]) > 0: self.in_counts += 1 - self.class_wise_count[self.names[cls]]["in"] += 1 + self.class_wise_count[self.names[cls]]["IN"] += 1 else: self.out_counts += 1 - self.class_wise_count[self.names[cls]]["out"] += 1 + self.class_wise_count[self.names[cls]]["OUT"] += 1 - label = "Ultralytics Analytics \t" + labels_dict = {} for key, value in self.class_wise_count.items(): - if value["in"] != 0 or value["out"] != 0: + if value["IN"] != 0 or value["OUT"] != 0: if not self.view_in_counts and not self.view_out_counts: - label = None + continue elif not self.view_in_counts: - label += f"{str.capitalize(key)}: IN {value['in']} \t" + labels_dict[str.capitalize(key)] = f"OUT {value['OUT']}" elif not self.view_out_counts: - label += f"{str.capitalize(key)}: OUT {value['out']} \t" + labels_dict[str.capitalize(key)] = f"IN {value['IN']}" else: - label += f"{str.capitalize(key)}: IN {value['in']} OUT {value['out']} \t" + labels_dict[str.capitalize(key)] = f"IN {value['IN']} OUT {value['OUT']}" - label = label.rstrip() - label = label.split("\t") - - if label is not None: - self.annotator.display_counts( - counts=label, - count_txt_color=self.count_txt_color, - count_bg_color=self.count_bg_color, - ) + if labels_dict is not None: + self.annotator.display_analytics(self.im0, labels_dict, self.count_txt_color, self.count_bg_color, 10) def display_frames(self): """Display frame.""" diff --git a/ultralytics/solutions/parking_management.py b/ultralytics/solutions/parking_management.py new file mode 100644 index 00000000000..98cff90c326 --- /dev/null +++ b/ultralytics/solutions/parking_management.py @@ -0,0 +1,235 @@ +import json +from tkinter import filedialog, messagebox + +import cv2 +import numpy as np +from PIL import Image, ImageTk + +from ultralytics.utils.checks import check_imshow, check_requirements +from ultralytics.utils.plotting import Annotator + +check_requirements("tkinter") +import tkinter as tk + + +class ParkingPtsSelection: + def __init__(self, master): + # Initialize window and widgets. + self.master = master + master.title("Ultralytics Parking Zones Points Selector") + self.initialize_ui() + + # Initialize properties + self.image_path = None + self.image = None + self.canvas_image = None + self.canvas = None + self.bounding_boxes = [] + self.current_box = [] + self.img_width = 0 + self.img_height = 0 + + # Constants + self.canvas_max_width = 1280 + self.canvas_max_height = 720 + + def initialize_ui(self): + """Setup UI components.""" + # Setup buttons + button_frame = tk.Frame(self.master) + button_frame.pack(side=tk.TOP) + + tk.Button(button_frame, text="Upload Image", command=self.upload_image).grid(row=0, column=0) + tk.Button(button_frame, text="Remove Last BBox", command=self.remove_last_bounding_box).grid(row=0, column=1) + tk.Button(button_frame, text="Save", command=self.save_to_json).grid(row=0, column=2) + + # Setup canvas for image display + self.canvas = tk.Canvas(self.master, bg="white") + self.canvas.pack(side=tk.BOTTOM) + self.canvas.bind("", self.on_canvas_click) + + def upload_image(self): + """Upload an image and resize it to fit canvas.""" + self.image_path = filedialog.askopenfilename(filetypes=[("Image Files", "*.png;*.jpg;*.jpeg")]) + if not self.image_path: + return + + self.image = Image.open(self.image_path) + self.img_width, self.img_height = self.image.size + + # Calculate the aspect ratio and resize image + aspect_ratio = self.img_width / self.img_height + if aspect_ratio > 1: + # Landscape orientation + canvas_width = min(self.canvas_max_width, self.img_width) + canvas_height = int(canvas_width / aspect_ratio) + else: + # Portrait orientation + canvas_height = min(self.canvas_max_height, self.img_height) + canvas_width = int(canvas_height * aspect_ratio) + + self.canvas.config(width=canvas_width, height=canvas_height) + resized_image = self.image.resize((canvas_width, canvas_height), Image.LANCZOS) + self.canvas_image = ImageTk.PhotoImage(resized_image) + self.canvas.create_image(0, 0, anchor=tk.NW, image=self.canvas_image) + + # Reset bounding boxes and current box + self.bounding_boxes = [] + self.current_box = [] + + def on_canvas_click(self, event): + """Handle mouse clicks on canvas to create points for bounding boxes.""" + self.current_box.append((event.x, event.y)) + + if len(self.current_box) == 4: + self.bounding_boxes.append(self.current_box) + self.draw_bounding_box(self.current_box) + self.current_box = [] + + def draw_bounding_box(self, box): + """Draw bounding box on canvas.""" + for i in range(4): + x1, y1 = box[i] + x2, y2 = box[(i + 1) % 4] + self.canvas.create_line(x1, y1, x2, y2, fill="blue", width=2) + + def remove_last_bounding_box(self): + """Remove the last drawn bounding box from canvas.""" + if self.bounding_boxes: + self.bounding_boxes.pop() # Remove the last bounding box + self.canvas.delete("all") # Clear the canvas + self.canvas.create_image(0, 0, anchor=tk.NW, image=self.canvas_image) # Redraw the image + + # Redraw all bounding boxes + for box in self.bounding_boxes: + self.draw_bounding_box(box) + + messagebox.showinfo("Success", "Last bounding box removed.") + else: + messagebox.showwarning("Warning", "No bounding boxes to remove.") + + def save_to_json(self): + canvas_width, canvas_height = self.canvas.winfo_width(), self.canvas.winfo_height() + width_scaling_factor = self.img_width / canvas_width + height_scaling_factor = self.img_height / canvas_height + bounding_boxes_data = [] + for box in self.bounding_boxes: + print("Bounding Box ", bounding_boxes_data) + rescaled_box = [] + for x, y in box: + rescaled_x = int(x * width_scaling_factor) + rescaled_y = int(y * height_scaling_factor) + rescaled_box.append((rescaled_x, rescaled_y)) + bounding_boxes_data.append({"points": rescaled_box}) + with open("bounding_boxes.json", "w") as json_file: + json.dump(bounding_boxes_data, json_file, indent=4) + + messagebox.showinfo("Success", "Bounding boxes saved to bounding_boxes.json") + + +class ParkingManagement: + def __init__( + self, + model_path, + txt_color=(0, 0, 0), + bg_color=(255, 255, 255), + occupied_region_color=(0, 255, 0), + available_region_color=(0, 0, 255), + margin=10, + ): + # Model path and initialization + self.model_path = model_path + self.model = self.load_model() + + # Labels dictionary + self.labels_dict = {"Occupancy": 0, "Available": 0} + + # Visualization details + self.margin = margin + self.bg_color = bg_color + self.txt_color = txt_color + self.occupied_region_color = occupied_region_color + self.available_region_color = available_region_color + + self.window_name = "Ultralytics YOLOv8 Parking Management System" + # Check if environment support imshow + self.env_check = check_imshow(warn=True) + + def load_model(self): + """Load the Ultralytics YOLOv8 model for inference and analytics.""" + from ultralytics import YOLO + + self.model = YOLO(self.model_path) + return self.model + + def parking_regions_extraction(self, json_file): + """ + Extract parking regions from json file. + + Args: + json_file (str): file that have all parking slot points + """ + + with open(json_file, "r") as json_file: + json_data = json.load(json_file) + return json_data + + def process_data(self, json_data, im0, boxes, clss): + """ + Process the model data for parking lot management. + + Args: + json_data (str): json data for parking lot management + im0 (ndarray): inference image + boxes (list): bounding boxes data + clss (list): bounding boxes classes list + Returns: + filled_slots (int): total slots that are filled in parking lot + empty_slots (int): total slots that are available in parking lot + """ + annotator = Annotator(im0) + total_slots, filled_slots = len(json_data), 0 + empty_slots = total_slots + + for region in json_data: + points = region["points"] + points_array = np.array(points, dtype=np.int32).reshape((-1, 1, 2)) + region_occupied = False + + for box, cls in zip(boxes, clss): + x_center = int((box[0] + box[2]) / 2) + y_center = int((box[1] + box[3]) / 2) + text = f"{self.model.names[int(cls)]}" + + annotator.display_objects_labels( + im0, text, self.txt_color, self.bg_color, x_center, y_center, self.margin + ) + dist = cv2.pointPolygonTest(points_array, (x_center, y_center), False) + if dist >= 0: + region_occupied = True + break + + color = self.occupied_region_color if region_occupied else self.available_region_color + cv2.polylines(im0, [points_array], isClosed=True, color=color, thickness=2) + if region_occupied: + filled_slots += 1 + empty_slots -= 1 + + self.labels_dict["Occupancy"] = filled_slots + self.labels_dict["Available"] = empty_slots + + annotator.display_analytics(im0, self.labels_dict, self.txt_color, self.bg_color, self.margin) + + def display_frames(self, im0): + """ + Display frame. + + Args: + im0 (ndarray): inference image + """ + if self.env_check: + cv2.namedWindow(self.window_name) + cv2.imshow(self.window_name, im0) + # Break Window + if cv2.waitKey(1) & 0xFF == ord("q"): + return diff --git a/ultralytics/utils/plotting.py b/ultralytics/utils/plotting.py index cbdf78739d4..5f01563a98d 100644 --- a/ultralytics/utils/plotting.py +++ b/ultralytics/utils/plotting.py @@ -419,51 +419,63 @@ def queue_counts_display(self, label, points=None, region_color=(255, 255, 255), lineType=cv2.LINE_AA, ) - def display_counts(self, counts=None, count_bg_color=(0, 0, 0), count_txt_color=(255, 255, 255)): + ### Parking management utils + def display_objects_labels(self, im0, text, txt_color, bg_color, x_center, y_center, margin): """ - Display counts on im0 with text background and border. + Display the bounding boxes labels in parking management app. Args: - counts (str): objects count data - count_bg_color (RGB Color): counts highlighter color - count_txt_color (RGB Color): counts display color + im0 (ndarray): inference image + text (str): object/class name + txt_color (bgr color): display color for text foreground + bg_color (bgr color): display color for text background + x_center (float): x position center point for bounding box + y_center (float): y position center point for bounding box + margin (int): gap between text and rectangle for better display """ - tl = self.tf or round(0.002 * (self.im.shape[0] + self.im.shape[1]) / 2) + 1 - tf = max(tl - 1, 1) + text_size = cv2.getTextSize(text, 0, fontScale=self.sf, thickness=self.tf)[0] + text_x = x_center - text_size[0] // 2 + text_y = y_center + text_size[1] // 2 - t_sizes = [cv2.getTextSize(str(count), 0, fontScale=self.sf, thickness=self.tf)[0] for count in counts] + rect_x1 = text_x - margin + rect_y1 = text_y - text_size[1] - margin + rect_x2 = text_x + text_size[0] + margin + rect_y2 = text_y + margin + cv2.rectangle(im0, (rect_x1, rect_y1), (rect_x2, rect_y2), bg_color, -1) + cv2.putText(im0, text, (text_x, text_y), 0, self.sf, txt_color, self.tf, lineType=cv2.LINE_AA) - max_text_width = max([size[0] for size in t_sizes]) - max_text_height = max([size[1] for size in t_sizes]) - - text_x = self.im.shape[1] - int(self.im.shape[1] * 0.025 + max_text_width) - text_y = int(self.im.shape[0] * 0.025) - - for i, count in enumerate(counts): - text_x_pos = text_x - text_y_pos = text_y + i * (max_text_height + 25 * tf) - - # Draw the border - cv2.rectangle( - self.im, - (text_x_pos - (10 * tf), text_y_pos - (10 * tf)), - (text_x_pos + max_text_width + (10 * tf), text_y_pos + max_text_height + (10 * tf)), - count_bg_color, - -1, - ) + # Parking lot and object counting app + def display_analytics(self, im0, text, txt_color, bg_color, margin): + """ + Display the overall statistics for parking lots + Args: + im0 (ndarray): inference image + text (dict): labels dictionary + txt_color (bgr color): display color for text foreground + bg_color (bgr color): display color for text background + margin (int): gap between text and rectangle for better display + """ - # Draw the count text + horizontal_gap = int(im0.shape[1] * 0.02) + vertical_gap = int(im0.shape[0] * 0.01) + + text_y_offset = 0 + + for label, value in text.items(): + txt = f"{label}: {value}" + text_size = cv2.getTextSize(txt, 0, int(self.sf * 1.5), int(self.tf * 1.5))[0] + text_x = im0.shape[1] - text_size[0] - margin * 2 - horizontal_gap + text_y = text_y_offset + text_size[1] + margin * 2 + vertical_gap + rect_x1 = text_x - margin * 2 + rect_y1 = text_y - text_size[1] - margin * 2 + rect_x2 = text_x + text_size[0] + margin * 2 + rect_y2 = text_y + margin * 2 + cv2.rectangle(im0, (rect_x1, rect_y1), (rect_x2, rect_y2), bg_color, -1) cv2.putText( - self.im, - str(count), - (text_x_pos, text_y_pos + max_text_height), - 0, - fontScale=self.sf, - color=count_txt_color, - thickness=self.tf, - lineType=cv2.LINE_AA, + im0, txt, (text_x, text_y), 0, int(self.sf * 1.5), txt_color, int(self.tf * 1.5), lineType=cv2.LINE_AA ) + text_y_offset = rect_y2 @staticmethod def estimate_pose_angle(a, b, c):