Skip to content

Commit

Permalink
Code Refactor for Speed and Readability (#55)
Browse files Browse the repository at this point in the history
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
  • Loading branch information
3 people authored Jun 9, 2024
1 parent 0b4f85b commit c6011a0
Show file tree
Hide file tree
Showing 9 changed files with 138 additions and 166 deletions.
26 changes: 14 additions & 12 deletions detect.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,11 @@
if platform == "darwin": # macos
parser.add_argument("-image_folder", type=str, default="/Users/glennjocher/Downloads/DATA/xview/train_images/5.tif")
parser.add_argument("-output_folder", type=str, default="./output_xview", help="path to outputs")
cuda = False # torch.cuda.is_available()
else: # gcp
# cd yolo && python3 detect.py -secondary_classifier 1
parser.add_argument("-image_folder", type=str, default="../train_images/5.tif", help="path to images")
parser.add_argument("-output_folder", type=str, default="../output", help="path to outputs")
cuda = False

cuda = False # torch.cuda.is_available()
parser.add_argument("-plot_flag", type=bool, default=True)
parser.add_argument("-secondary_classifier", type=bool, default=False)
parser.add_argument("-cfg", type=str, default="cfg/c60_a30symmetric.cfg", help="cfg file path")
Expand All @@ -35,9 +33,9 @@
def detect(opt):
"""Detects objects in images using Darknet model, optionally uses a secondary classifier, and performs NMS."""
if opt.plot_flag:
os.system("rm -rf " + opt.output_folder + "_img")
os.makedirs(opt.output_folder + "_img", exist_ok=True)
os.system("rm -rf " + opt.output_folder)
os.system(f"rm -rf {opt.output_folder}_img")
os.makedirs(f"{opt.output_folder}_img", exist_ok=True)
os.system(f"rm -rf {opt.output_folder}")
os.makedirs(opt.output_folder, exist_ok=True)
device = torch.device("cuda:0" if cuda else "cpu")

Expand Down Expand Up @@ -150,7 +148,7 @@ def detect(opt):
# pred[:, 1] += y1
# preds.append(pred.unsqueeze(0))

if len(preds) > 0:
if preds:
detections = non_max_suppression(
torch.cat(preds, 1), opt.conf_thres, opt.nms_thres, mat_priors, img, model2, device
)
Expand All @@ -163,7 +161,7 @@ def detect(opt):
# Bounding-box colors
color_list = [[random.randint(0, 255), random.randint(0, 255), random.randint(0, 255)] for _ in range(len(classes))]

if len(img_detections) == 0:
if not img_detections:
return

# Iterate through images and save plot of detections
Expand All @@ -187,10 +185,10 @@ def detect(opt):

# write results to .txt file
results_path = os.path.join(opt.output_folder, path.split("/")[-1])
if os.path.isfile(results_path + ".txt"):
os.remove(results_path + ".txt")
if os.path.isfile(f"{results_path}.txt"):
os.remove(f"{results_path}.txt")

results_img_path = os.path.join(opt.output_folder + "_img", path.split("/")[-1])
results_img_path = os.path.join(f"{opt.output_folder}_img", path.split("/")[-1])
with open(results_path.replace(".bmp", ".tif") + ".txt", "a") as file:
for i in unique_classes:
n = (detections[:, -1].cpu() == i).sum()
Expand Down Expand Up @@ -224,7 +222,11 @@ def detect(opt):
if opt.plot_flag:
from scoring import score

score.score(opt.output_folder + "/", "/Users/glennjocher/Downloads/DATA/xview/xView_train.geojson", ".")
score.score(
f"{opt.output_folder}/",
"/Users/glennjocher/Downloads/DATA/xview/xView_train.geojson",
".",
)


class ConvNetb(nn.Module):
Expand Down
6 changes: 3 additions & 3 deletions models.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def create_modules(module_defs):

elif module_def["type"] == "route":
layers = [int(x) for x in module_def["layers"].split(",")]
filters = sum([output_filters[layer_i] for layer_i in layers])
filters = sum(output_filters[layer_i] for layer_i in layers)
modules.add_module("route_%d" % i, EmptyLayer())

elif module_def["type"] == "shortcut":
Expand Down Expand Up @@ -83,7 +83,7 @@ def __init__(self, anchors, nC, img_dim, anchor_idxs):
"""Initializes YOLO layer with given anchors, number of classes, image dimensions, and anchor indexes."""
super(YOLOLayer, self).__init__()

anchors = [(a_w, a_h) for a_w, a_h in anchors] # (pixels)
anchors = list(anchors)
nA = len(anchors)

self.anchors = anchors
Expand Down Expand Up @@ -165,7 +165,7 @@ def forward(self, p, targets=None, requestPrecision=False, weight=None, epoch=No

# Mask outputs to ignore non-existing objects (but keep confidence predictions)
nM = mask.sum().float()
nGT = sum([len(x) for x in targets])
nGT = sum(len(x) for x in targets)
if nM > 0:
# wC = weight[torch.argmax(tcls, 1)] # weight class
# wC /= sum(wC)
Expand Down
33 changes: 12 additions & 21 deletions scoring/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,7 @@

def safe_divide(numerator, denominator):
"""Computes the safe division to avoid the divide by zero problem."""
if denominator == 0:
return 0
return numerator / denominator
return 0 if denominator == 0 else numerator / denominator


def compute_statistics_given_rectangle_matches(groundtruth_rects_matched, rects_matched):
Expand Down Expand Up @@ -94,13 +92,16 @@ def compute_average_precision_recall_given_precision_recall_dict(precision_recal

def convert_to_rectangle_list(coordinates):
"""Converts the coordinates in a list to the Rectangle list."""
rectangle_list = []
number_of_rects = int(len(coordinates) / 4)
for i in range(number_of_rects):
rectangle_list.append(
Rectangle(coordinates[4 * i], coordinates[4 * i + 1], coordinates[4 * i + 2], coordinates[4 * i + 3])
number_of_rects = len(coordinates) // 4
return [
Rectangle(
coordinates[4 * i],
coordinates[4 * i + 1],
coordinates[4 * i + 2],
coordinates[4 * i + 3],
)
return rectangle_list
for i in range(number_of_rects)
]


def compute_average_precision_recall(groundtruth_coordinates, coordinates, iou_threshold):
Expand Down Expand Up @@ -149,18 +150,8 @@ def compute_average_precision_recall(groundtruth_coordinates, coordinates, iou_t
rects = convert_to_rectangle_list(coordinates)
matching = Matching(groundtruth_rects, rects)

image_statistics_list = []
groundtruth_rects_matched, rects_matched = matching.matching_by_greedy_assignment(iou_threshold)

image_statistics = compute_statistics_given_rectangle_matches(groundtruth_rects_matched, rects_matched)
image_statistics_list.append(image_statistics)

# Compute the precision and recall under this iou_threshold.
precision_recall = compute_precision_recall_given_image_statistics_list(iou_threshold, image_statistics_list)

# Compute the average_precision and average_recall.
# average_precision, average_recall = (
# compute_average_precision_recall_given_precision_recall_dict(
# precision_recall_dict))

return precision_recall
image_statistics_list = [image_statistics]
return compute_precision_recall_given_image_statistics_list(iou_threshold, image_statistics_list)
15 changes: 8 additions & 7 deletions scoring/matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def _compute_iou_from_rectangle_pairs(self):

self.iou_rectangle_pair_indices_ = defaultdict(list)

if not (n == 0 or m == 0):
if n != 0 and m != 0:
mat2 = np.array([j.coords for j in self.groundtruth_rects_])
mat1 = np.array([j.coords for j in self.rects_])
# i,j axes correspond to #boxes, #coords per rect
Expand Down Expand Up @@ -88,19 +88,20 @@ def _compute_iou_from_rectangle_pairs(self):

def greedy_match(self, iou_threshold):
"""Performs greedy matching of rectangles based on IOU threshold, returning matched indices."""
gt_rects_matched = [False for gt_index in range(self.m)]
rects_matched = [False for r_index in range(self.n)]
gt_rects_matched = [False for _ in range(self.m)]
rects_matched = [False for _ in range(self.n)]

if self.n == 0:
return [], []
elif self.m == 0:
return rects_matched, []

for i, gt_index in enumerate(np.argmax(self.iou_matrix, axis=1)):
if self.iou_matrix[i, gt_index] >= iou_threshold:
if gt_rects_matched[gt_index] is False and rects_matched[i] is False:
rects_matched[i] = True
gt_rects_matched[gt_index] = True
if self.iou_matrix[i, gt_index] >= iou_threshold and (
gt_rects_matched[gt_index] is False and rects_matched[i] is False
):
rects_matched[i] = True
gt_rects_matched[gt_index] = True
return rects_matched, gt_rects_matched


Expand Down
44 changes: 19 additions & 25 deletions scoring/score.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,13 +95,16 @@ def convert_to_rectangle_list(coordinates):
Outputs:
A list of rectangles
"""
rectangle_list = []
number_of_rects = int(len(coordinates) / 4)
for i in range(number_of_rects):
rectangle_list.append(
Rectangle(coordinates[4 * i], coordinates[4 * i + 1], coordinates[4 * i + 2], coordinates[4 * i + 3])
number_of_rects = len(coordinates) // 4
return [
Rectangle(
coordinates[4 * i],
coordinates[4 * i + 1],
coordinates[4 * i + 2],
coordinates[4 * i + 3],
)
return rectangle_list
for i in range(number_of_rects)
]


def ap_from_pr(p, r):
Expand All @@ -121,10 +124,8 @@ def ap_from_pr(p, r):
if p[i] > p[i - 1]:
p[i - 1] = p[i]

i = np.where(r[1:] != r[: len(r) - 1])[0] + 1
ap = np.sum((r[i] - r[i - 1]) * p[i])

return ap
i = np.where(r[1:] != r[:-1])[0] + 1
return np.sum((r[i] - r[i - 1]) * p[i])


# @profile
Expand Down Expand Up @@ -206,10 +207,7 @@ def score(path_predictions, path_groundtruth, path_output, iou_threshold=0.5):
print("Number of Predictions: %d" % num_preds)
print("Number of GT: %d" % np.sum(gt_classes.shape))

per_file_class_data = {}
for i in gt_unique:
per_file_class_data[i] = [[], []]

per_file_class_data = {i: [[], []] for i in gt_unique}
num_gt_per_cls = np.zeros((max_gt_cls))

attempted = np.zeros(100)
Expand Down Expand Up @@ -432,21 +430,17 @@ def score(path_predictions, path_groundtruth, path_output, iou_threshold=0.5):
with open("data/xview.names") as f:
lines = f.readlines()

map_dict = {}
for i in range(60):
map_dict[lines[i].replace("\n", "")] = average_precision_per_class[int(n[i])]

map_dict = {lines[i].replace("\n", ""): average_precision_per_class[int(n[i])] for i in range(60)}
print(np.nansum(per_class_rcount), map_dict)
vals = {}
vals["map"] = np.nanmean(average_precision_per_class)
vals = {"map": np.nanmean(average_precision_per_class)}
vals["map_score"] = np.nanmean(per_class_p)
vals["mar_score"] = np.nanmean(per_class_r)

a = np.concatenate(
(average_precision_per_class, per_class_p, per_class_r, per_class_rcount, num_gt_per_cls)
).reshape(5, 100)

for i in splits.keys():
for i in splits:
vals[i] = np.nanmean(average_precision_per_class[splits[i]])

v2 = np.zeros((62, 5))
Expand All @@ -469,15 +463,15 @@ def score(path_predictions, path_groundtruth, path_output, iou_threshold=0.5):
# with open(path_output + '/score.txt', 'w') as f:
# f.write(str("%.8f" % vals['map']))
#
with open(path_output + "/metrics.txt", "w") as f:
for key in vals.keys():
f.write("%s %f\n" % (str(key), vals[key]))
with open(f"{path_output}/metrics.txt", "w") as f:
for key, value in vals.items():
f.write("%s %f\n" % (str(key), value))
# for key in vals.keys():
# f.write("%f\n" % (vals[key]))
for i in range(len(v2)):
f.write(("%g, " * 5 + "\n") % (v2[i, 0], v2[i, 1], v2[i, 2], v2[i, 3], v2[i, 4]))

print("Final time: %s" % str(time.time() - ttime))
print(f"Final time: {str(time.time() - ttime)}")


if __name__ == "__main__":
Expand Down
24 changes: 8 additions & 16 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ def main(opt):
% ("Epoch", "Batch", "x", "y", "w", "h", "conf", "cls", "total", "P", "R", "nGT", "TP", "FP", "FN", "time")
)
class_weights = xview_class_weights_hard_mining(range(60)).to(device)
n = 4 # number of pictures at a time
for epoch in range(opt.epochs):
epoch += start_epoch

Expand All @@ -121,10 +122,9 @@ def main(opt):
rloss = defaultdict(float) # running loss
metrics = torch.zeros(4, 60)
for i, (imgs, targets) in enumerate(dataloader):
n = 4 # number of pictures at a time
for j in range(int(len(imgs) / n)):
for j in range(len(imgs) // n):
targets_j = targets[j * n : j * n + n]
nGT = sum([len(x) for x in targets_j])
nGT = sum(len(x) for x in targets_j)
if nGT < 1:
continue

Expand All @@ -147,19 +147,11 @@ def main(opt):
# Precision
precision = metrics[0] / (metrics[0] + metrics[1] + 1e-16)
k = (metrics[0] + metrics[1]) > 0
if k.sum() > 0:
mean_precision = precision[k].mean()
else:
mean_precision = 0

mean_precision = precision[k].mean() if k.sum() > 0 else 0
# Recall
recall = metrics[0] / (metrics[0] + metrics[2] + 1e-16)
k = (metrics[0] + metrics[2]) > 0
if k.sum() > 0:
mean_recall = recall[k].mean()
else:
mean_recall = 0

mean_recall = recall[k].mean() if k.sum() > 0 else 0
s = ("%10s%10s" + "%10.3g" * 14) % (
"%g/%g" % (epoch, opt.epochs - 1),
"%g/%g" % (i, len(dataloader) - 1),
Expand All @@ -181,8 +173,8 @@ def main(opt):
t1 = time.time()
print(s)

# if i == 1:
# return
# if i == 1:
# return

# # Update dynamic class weights
# new_weights = metrics[3]
Expand Down Expand Up @@ -218,7 +210,7 @@ def main(opt):

# Save backup checkpoint
if (epoch > 0) & (epoch % 100 == 0):
os.system("cp weights/latest.pt weights/backup" + str(epoch) + ".pt")
os.system(f"cp weights/latest.pt weights/backup{epoch}.pt")

# Save final model
dt = time.time() - t0
Expand Down
Loading

0 comments on commit c6011a0

Please sign in to comment.