Skip to content

Commit

Permalink
Create one_cycle() function (#1836)
Browse files Browse the repository at this point in the history
  • Loading branch information
glenn-jocher committed Jan 4, 2021
1 parent 7dddb1d commit 0e341c5
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 3 deletions.
6 changes: 3 additions & 3 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from utils.datasets import create_dataloader
from utils.general import labels_to_class_weights, increment_path, labels_to_image_weights, init_seeds, \
fitness, strip_optimizer, get_latest_run, check_dataset, check_file, check_git_status, check_img_size, \
print_mutation, set_logging
print_mutation, set_logging, one_cycle
from utils.google_utils import attempt_download
from utils.loss import compute_loss
from utils.plots import plot_images, plot_labels, plot_results, plot_evolution
Expand Down Expand Up @@ -126,12 +126,12 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):

# Scheduler https://arxiv.org/pdf/1812.01187.pdf
# https://pytorch.org/docs/stable/_modules/torch/optim/lr_scheduler.html#OneCycleLR
lf = lambda x: ((1 + math.cos(x * math.pi / epochs)) / 2) * (1 - hyp['lrf']) + hyp['lrf'] # cosine
lf = one_cycle(1, hyp['lrf'], epochs) # cosine 1->hyp['lrf']
scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
# plot_lr_scheduler(optimizer, scheduler, epochs)

# Logging
if wandb and wandb.run is None:
if rank in [-1, 0] and wandb and wandb.run is None:
opt.hyp = hyp # add hyperparameters
wandb_run = wandb.init(config=opt, resume="allow",
project='YOLOv5' if opt.project == 'runs/train' else Path(opt.project).stem,
Expand Down
5 changes: 5 additions & 0 deletions utils/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,11 @@ def clean_str(s):
return re.sub(pattern="[|@#!¡·$€%&()=?¿^*;:,¨´><+]", repl="_", string=s)


def one_cycle(y1=0.0, y2=1.0, steps=100):
# lambda function for sinusoidal ramp from y1 to y2
return lambda x: ((1 - math.cos(x * math.pi / steps)) / 2) * (y2 - y1) + y1


def labels_to_class_weights(labels, nc=80):
# Get class weights (inverse frequency) from training labels
if labels[0] is None: # no labels loaded
Expand Down
1 change: 1 addition & 0 deletions utils/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,7 @@ def plot_lr_scheduler(optimizer, scheduler, epochs=300, save_dir=''):
plt.xlim(0, epochs)
plt.ylim(0)
plt.savefig(Path(save_dir) / 'LR.png', dpi=200)
plt.close()


def plot_test_txt(): # from utils.plots import *; plot_test()
Expand Down

0 comments on commit 0e341c5

Please sign in to comment.