Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add new --save-csv argument to detect.py #12042

Merged
merged 8 commits into from
Sep 4, 2023
Merged
23 changes: 23 additions & 0 deletions detect.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
"""

import argparse
import csv
import os
import platform
import sys
Expand Down Expand Up @@ -63,6 +64,7 @@ def run(
device='', # cuda device, i.e. 0 or 0,1,2,3 or cpu
view_img=False, # show results
save_txt=False, # save results to *.txt
save_csv=False, # save results in CSV format
save_conf=False, # save confidences in --save-txt labels
save_crop=False, # save cropped prediction boxes
nosave=False, # do not save images/videos
Expand Down Expand Up @@ -135,6 +137,18 @@ def run(
# Second-stage classifier (optional)
# pred = utils.general.apply_classifier(pred, classifier_model, im, im0s)

# Define the path for the CSV file
csv_path = save_dir / 'predictions.csv'

# Create or append to the CSV file
def write_to_csv(image_name, prediction, confidence):
data = {'Image Name': image_name, 'Prediction': prediction, 'Confidence': confidence}
with open(csv_path, mode='a', newline='') as f:
writer = csv.DictWriter(f, fieldnames=data.keys())
if not csv_path.is_file():
writer.writeheader()
writer.writerow(data)

# Process predictions
for i, det in enumerate(pred): # per image
seen += 1
Expand Down Expand Up @@ -162,6 +176,14 @@ def run(

# Write results
for *xyxy, conf, cls in reversed(det):
c = int(cls) # integer class
label = names[c] if hide_conf else f'{names[c]}'
confidence = float(conf)
confidence_str = f'{confidence:.2f}'

if save_csv:
write_to_csv(p.name, label, confidence_str)

if save_txt: # Write to file
xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist() # normalized xywh
line = (cls, *xywh, conf) if save_conf else (cls, *xywh) # label format
Expand Down Expand Up @@ -229,6 +251,7 @@ def parse_opt():
parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
parser.add_argument('--view-img', action='store_true', help='show results')
parser.add_argument('--save-txt', action='store_true', help='save results to *.txt')
parser.add_argument('--save-csv', action='store_true', help='save results in CSV format')
parser.add_argument('--save-conf', action='store_true', help='save confidences in --save-txt labels')
parser.add_argument('--save-crop', action='store_true', help='save cropped prediction boxes')
parser.add_argument('--nosave', action='store_true', help='do not save images/videos')
Expand Down