Skip to content

Commit

Permalink
Ultralytics Code Refactor https://ultralytics.com/actions (#2235)
Browse files Browse the repository at this point in the history
Refactor code for speed and clarity
  • Loading branch information
glenn-jocher committed Jun 30, 2024
1 parent a6e437c commit b93ce58
Show file tree
Hide file tree
Showing 13 changed files with 30 additions and 5 deletions.
1 change: 1 addition & 0 deletions classify/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@ def train(opt, device):

# lf = lambda x: ((1 + math.cos(x * math.pi / epochs)) / 2) * (1 - lrf) + lrf # cosine
def lf(x):
"""Linear learning rate scheduler function, scaling learning rate from initial value to `lrf` over `epochs`."""
return (1 - x / epochs) * (1 - lrf) + lrf # linear

scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
Expand Down
5 changes: 4 additions & 1 deletion export.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ def try_export(inner_func):
inner_args = get_default_args(inner_func)

def outer_func(*args, **kwargs):
"""Profiles and logs the export process of YOLOv3 models, capturing success or failure details."""
prefix = inner_args["prefix"]
try:
with Profile() as dt:
Expand Down Expand Up @@ -226,7 +227,7 @@ def export_onnx(model, im, file, opset, dynamic, simplify, prefix=colorstr("ONNX

@try_export
def export_openvino(file, metadata, half, int8, data, prefix=colorstr("OpenVINO:")):
# YOLOv3 OpenVINO export
"""Exports a YOLOv3 model to OpenVINO format, with optional INT8 quantization and inference metadata."""
check_requirements("openvino-dev>=2023.0") # requires openvino-dev: https://pypi.org/project/openvino-dev/
import openvino.runtime as ov # noqa
from openvino.tools import mo # noqa
Expand All @@ -247,6 +248,7 @@ def export_openvino(file, metadata, half, int8, data, prefix=colorstr("OpenVINO:
onnx_model = core.read_model(f_onnx) # export

def prepare_input_tensor(image: np.ndarray):
"""Prepares the input tensor by normalizing pixel values and converting the datatype to float32."""
input_tensor = image.astype(np.float32) # uint8 to fp16/32
input_tensor /= 255.0 # 0 - 255 to 0.0 - 1.0

Expand All @@ -255,6 +257,7 @@ def prepare_input_tensor(image: np.ndarray):
return input_tensor

def gen_dataloader(yaml_path, task="train", imgsz=640, workers=4):
"""Generates a PyTorch dataloader for the specified task using dataset configurations from a YAML file."""
data_yaml = check_yaml(yaml_path)
data = check_dataset(data_yaml)
dataloader = create_dataloader(
Expand Down
1 change: 1 addition & 0 deletions segment/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,7 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio
else:

def lf(x):
"""Linear learning rate scheduler decreasing from 1 to hyp['lrf'] over the course of given epochs."""
return (1 - x / epochs) * (1.0 - hyp["lrf"]) + hyp["lrf"] # linear

scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf) # plot_lr_scheduler(optimizer, scheduler, epochs)
Expand Down
1 change: 1 addition & 0 deletions segment/val.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ def save_one_json(predn, jdict, path, class_map, pred_masks):
from pycocotools.mask import encode

def single_encode(x):
"""Encodes a binary mask to COCO RLE format, converting counts to a UTF-8 string for JSON serialization."""
rle = encode(np.asarray(x[:, :, None], order="F", dtype="uint8"))[0]
rle["counts"] = rle["counts"].decode("utf-8")
return rle
Expand Down
1 change: 1 addition & 0 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,7 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio
else:

def lf(x):
"""Linear learning rate scheduler function with decay calculated by epoch proportion."""
return (1 - x / epochs) * (1.0 - hyp["lrf"]) + hyp["lrf"] # linear

scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf) # plot_lr_scheduler(optimizer, scheduler, epochs)
Expand Down
5 changes: 5 additions & 0 deletions utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,11 @@ def threaded(func):
"""

def wrapper(*args, **kwargs):
"""
Runs the decorated function in a separate thread and returns the thread object.
Usage: @threaded.
"""
thread = threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True)
thread.start()
return thread
Expand Down
6 changes: 6 additions & 0 deletions utils/autoanchor.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ def check_anchors(dataset, model, thr=4.0, imgsz=640):
wh = torch.tensor(np.concatenate([l[:, 3:5] * s for s, l in zip(shapes * scale, dataset.labels)])).float() # wh

def metric(k): # compute metric
"""Computes and returns best possible recall (bpr) and anchors above threshold (aat) metrics for given
anchors.
"""
r = wh[:, None] / k[None]
x = torch.min(r, 1 / r).min(2)[0] # ratio metric
best = x.max(1)[0] # best_x
Expand Down Expand Up @@ -86,16 +89,19 @@ def kmean_anchors(dataset="./data/coco128.yaml", n=9, img_size=640, thr=4.0, gen
thr = 1 / thr

def metric(k, wh): # compute metrics
"""Computes best possible recall (BPR) and anchors above threshold (AAT) metrics for given anchor boxes."""
r = wh[:, None] / k[None]
x = torch.min(r, 1 / r).min(2)[0] # ratio metric
# x = wh_iou(wh, torch.tensor(k)) # iou metric
return x, x.max(1)[0] # x, best_x

def anchor_fitness(k): # mutation fitness
"""Evaluates the fitness of anchor boxes by computing mean recall weighted by an activation threshold."""
_, best = metric(torch.tensor(k, dtype=torch.float32), wh)
return (best * (best > thr).float()).mean() # fitness

def print_results(k, verbose=True):
"""Displays sorted anchors and their metrics including best possible recall and anchors above threshold."""
k = k[np.argsort(k.prod(1))] # sort small to large
x, best = metric(k, wh0)
bpr, aat = (best > thr).float().mean(), (x > thr).float().mean() * n # best possible recall, anch > thr
Expand Down
2 changes: 1 addition & 1 deletion utils/downloads.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def attempt_download(file, repo="ultralytics/yolov5", release="v7.0"):
from utils.general import LOGGER

def github_assets(repository, version="latest"):
# Return GitHub repo tag (i.e. 'v7.0') and assets (i.e. ['yolov5s.pt', 'yolov5m.pt', ...])
"""Returns GitHub tag and assets for a given repository and version from the GitHub API."""
if version != "latest":
version = f"tags/{version}" # i.e. tags/v7.0
response = requests.get(f"https://github.com/gitapi/repos/{repository}/releases/{version}").json() # github api
Expand Down
8 changes: 5 additions & 3 deletions utils/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@ def check_online():
import socket

def run_once():
# Check once
"""Attempts a single internet connectivity check to '1.1.1.1' on port 443 and returns True if successful."""
try:
socket.create_connection(("1.1.1.1", 443), 5) # check host accessibility
return True
Expand Down Expand Up @@ -584,7 +584,7 @@ def check_amp(model):
from models.common import AutoShape, DetectMultiBackend

def amp_allclose(model, im):
# All close FP32 vs AMP results
"""Compares FP32 and AMP inference results for a model and image, ensuring outputs are within 10% tolerance."""
m = AutoShape(model, verbose=False) # model
a = m(im).xywhn[0] # FP32 inference
m.amp = True
Expand Down Expand Up @@ -645,7 +645,9 @@ def download(url, dir=".", unzip=True, delete=True, curl=False, threads=1, retry
"""

def download_one(url, dir):
# Download 1 file
"""Downloads a file from a URL into the specified directory, supporting retries and using curl or torch
methods.
"""
success = True
if os.path.isfile(url):
f = Path(url) # filename
Expand Down
1 change: 1 addition & 0 deletions utils/loggers/comet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ class CometLogger:
"""Log metrics, parameters, source code, models and much more with Comet."""

def __init__(self, opt, hyp, run_id=None, job_type="Training", **experiment_kwargs) -> None:
"""Initialize the CometLogger instance with experiment configurations and hyperparameters for logging."""
self.job_type = job_type
self.opt = opt
self.hyp = hyp
Expand Down
1 change: 1 addition & 0 deletions utils/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ def butter_lowpass_filtfilt(data, cutoff=1500, fs=50000, order=5):

# https://stackoverflow.com/questions/28536191/how-to-filter-smooth-with-scipy-numpy
def butter_lowpass(cutoff, fs, order):
"""Applies a low-pass Butterworth filter to input data using forward-backward method; see https://stackoverflow.com/questions/28536191/how-to-filter-smooth-with-scipy-numpy."""
nyq = 0.5 * fs
normal_cutoff = cutoff / nyq
return butter(order, normal_cutoff, btype="low", analog=False)
Expand Down
2 changes: 2 additions & 0 deletions utils/segment/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def ap_per_class_box_and_mask(

class Metric:
def __init__(self) -> None:
"""Initializes Metric class attributes for precision, recall, F1 score, AP values, and AP class indices."""
self.p = [] # (nc, )
self.r = [] # (nc, )
self.f1 = [] # (nc, )
Expand Down Expand Up @@ -153,6 +154,7 @@ class Metrics:
"""Metric for boxes and masks."""

def __init__(self) -> None:
"""Initializes the Metrics class with separate Metric instances for boxes and masks."""
self.metric_box = Metric()
self.metric_mask = Metric()

Expand Down
1 change: 1 addition & 0 deletions utils/torch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def smart_inference_mode(torch_1_9=check_version(torch.__version__, "1.9.0")):
"""Applies torch.inference_mode() if torch>=1.9.0 or torch.no_grad() otherwise as a decorator to functions."""

def decorate(fn):
"""Applies torch.inference_mode() if torch>=1.9.0, otherwise torch.no_grad(), as a decorator to functions."""
return (torch.inference_mode if torch_1_9 else torch.no_grad)()(fn)

return decorate
Expand Down

0 comments on commit b93ce58

Please sign in to comment.