Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Jun 4, 2023
1 parent fc3be9e commit 9831e83
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions classify/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@

import argparse
import csv
import json
import io
import json
import os
import platform
import sys
Expand Down Expand Up @@ -109,8 +109,8 @@ def run(
# Run inference
model.warmup(imgsz=(1 if pt else bs, 3, *imgsz)) # warmup
seen, windows, dt = 0, [], (Profile(), Profile(), Profile())
with open(save_dir / 'predictions.csv', 'w', newline='') as csvfile: # Open CSV file for saving all predictions

with open(save_dir / 'predictions.csv', 'w', newline='') as csvfile: # Open CSV file for saving all predictions
csv_output = csv.DictWriter(csvfile, fieldnames=['path', 'label', 'confidence', 'top_5_predicted'])
csv_output.writeheader()

Expand All @@ -127,8 +127,8 @@ def run(

# Post-process
with dt[2]:
pred = F.softmax(results, dim=1) # probabilities
pred = F.softmax(results, dim=1) # probabilities

# Process predictions
for i, prob in enumerate(pred): # per image
seen += 1
Expand All @@ -140,7 +140,8 @@ def run(

p = Path(p) # to Path
save_path = str(save_dir / p.name) # im.jpg
txt_path = str(save_dir / 'labels' / p.stem) + ('' if dataset.mode == 'image' else f'_{frame}') # im.txt
txt_path = str(save_dir / 'labels' / p.stem) + ('' if dataset.mode == 'image' else f'_{frame}'
) # im.txt

s += '%gx%g ' % im.shape[2:] # print string
annotator = Annotator(im0, example=str(names), pil=True)
Expand All @@ -155,8 +156,7 @@ def run(
'path': path,
'top_5_predicted': json.dumps([(names[j], prob[j].item()) for j in top5i]),
'label': names[top_pred_index],
'confidence': f'{prob[top_pred_index]:.2f}'
})
'confidence': f'{prob[top_pred_index]:.2f}'})

# Write results
text = '\n'.join(f'{prob[j]:.2f} {names[j]}' for j in top5i)
Expand Down

0 comments on commit 9831e83

Please sign in to comment.