Commit: Update train.py
glenn-jocher committed Jun 28, 2021
1 parent dd5a574 commit 53aae7f
Showing 1 changed file with 8 additions and 29 deletions.
37 changes: 8 additions & 29 deletions train.py
@@ -6,7 +6,6 @@
 
 import argparse
 import logging
-import math
 import os
 import random
 import sys
@@ -16,6 +15,7 @@
 from pathlib import Path
 from threading import Thread
 
+import math
 import numpy as np
 import torch.distributed as dist
 import torch.nn as nn
@@ -53,30 +53,6 @@
 WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))
 
 
-def get_step_condition(nw, accumulate):
-    "Supports gradient accumulation and warmup for `nw` iterations"
-    nw, accumulate = int(nw), int(accumulate)
-
-    def acc(ni):
-        "Ramp up number of accumulated grads from 1 to `accumulate`"
-        return int(1 + (accumulate - 1) / nw * ni)
-
-    warmup_steps, ni = [], nw - 1
-    while ni >= 0:
-        warmup_steps.append(ni)
-        ni -= acc(ni)
-    warmup_steps = set(warmup_steps)
-
-    def step_condition(ni):
-        "Return whether optimizer.step() is to be called in iteration `ni`"
-        nonlocal nw, accumulate, warmup_steps
-        i = ni - nw
-        if (i >= 0) and ((i + 1) % accumulate == 0): return True
-        return ni in warmup_steps
-
-    return step_condition
-
-
 def train(hyp,  # path/to/hyp.yaml or hyp dictionary
           opt,
           device,
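
For context, the removed helper precomputed the set of warmup iterations at which the optimizer should step. A minimal standalone sketch of what it computed, with illustrative values (nw=10, accumulate=4 are assumptions, not from the commit):

# Sketch of what the removed get_step_condition precomputed.
# nw=10 and accumulate=4 are illustrative values, not from the commit.
nw, accumulate = 10, 4

def acc(ni):
    # Ramp the number of accumulated grads from 1 up to `accumulate`
    return int(1 + (accumulate - 1) / nw * ni)

warmup_steps, ni = [], nw - 1
while ni >= 0:
    warmup_steps.append(ni)
    ni -= acc(ni)

print(sorted(warmup_steps))  # [0, 1, 2, 4, 6, 9]: warmup iterations that step
# After warmup (ni >= nw), the closure stepped whenever (ni - nw + 1) % accumulate == 0.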
@@ -165,7 +141,7 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
 
     # Optimizer
     nbs = 64  # nominal batch size
-    accumulate = max(int(round(nbs / batch_size)), 1)  # accumulate loss before optimizing
+    accumulate = max(round(nbs / batch_size), 1)  # accumulate loss before optimizing
     hyp['weight_decay'] *= batch_size * accumulate / nbs  # scale weight_decay
     logger.info(f"Scaled weight_decay = {hyp['weight_decay']}")
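
A worked example of the accumulate and weight-decay scaling above; the batch sizes are assumed for illustration:

# Worked example of the optimizer setup above; batch sizes are illustrative.
nbs, batch_size = 64, 16
accumulate = max(round(nbs / batch_size), 1)  # -> 4: step once per 4 batches
scale = batch_size * accumulate / nbs         # -> 1.0: weight_decay unchanged
# With batch_size=48: accumulate = max(round(64 / 48), 1) = 1 and scale = 0.75,
# so weight_decay shrinks with the smaller effective batch.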

@@ -292,8 +268,9 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
 
     # Start training
     t0 = time.time()
-    nw = round(hyp['warmup_epochs'] * nb)  # number of warmup iterations
-    do_step = get_step_condition(nw, accumulate)
+    nw = max(round(hyp['warmup_epochs'] * nb), 1000)  # number of warmup iterations, max(3 epochs, 1k iterations)
+    # nw = min(nw, (epochs - start_epoch) / 2 * nb)  # limit warmup to < 1/2 of training
+    last_opt_step = -1
     maps = np.zeros(nc)  # mAP per class
     results = (0, 0, 0, 0, 0, 0, 0)  # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls)
     scheduler.last_epoch = start_epoch - 1  # do not move
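
The new nw line enforces a floor of 1000 warmup iterations; for instance, with assumed values:

# Assumed values: warmup_epochs=3, nb=200 batches per epoch.
nw = max(round(3 * 200), 1000)  # 600 < 1000, so the 1k-iteration floor applies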
@@ -340,6 +317,7 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
             if ni <= nw:
                 xi = [0, nw]  # x interp
                 # model.gr = np.interp(ni, xi, [0.0, 1.0])  # iou loss ratio (obj_loss = 1.0 or iou)
+                accumulate = max(1, np.interp(ni, xi, [1, nbs / batch_size]).round())
                 for j, x in enumerate(optimizer.param_groups):
                     # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
                     x['lr'] = np.interp(ni, xi, [hyp['warmup_bias_lr'] if j == 2 else 0.0, x['initial_lr'] * lf(epoch)])
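
The restored in-loop line ramps accumulate linearly over warmup via np.interp; a small sketch with assumed values for nbs, batch_size and nw:

import numpy as np

# Sketch of the warmup accumulation ramp; nbs, batch_size and nw are assumed.
nbs, batch_size, nw = 64, 16, 1000
for ni in (0, 250, 500, 750, 1000):
    accumulate = max(1, np.interp(ni, [0, nw], [1, nbs / batch_size]).round())
    print(ni, accumulate)  # ramps from 1.0 at ni=0 to 4.0 at ni=nw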
@@ -367,12 +345,13 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
             scaler.scale(loss).backward()
 
             # Optimize
-            if do_step(ni):
+            if ni - last_opt_step >= accumulate:
                 scaler.step(optimizer)  # optimizer.step
                 scaler.update()
                 optimizer.zero_grad()
                 if ema:
                     ema.update(model)
+                last_opt_step = ni
 
             # Print
             if RANK in [-1, 0]:
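
Taken together, the commit swaps the precomputed step schedule for a simple distance check between the batch counter and the last optimizer step. A minimal self-contained sketch of the pattern (toy model and data; the AMP scaler and EMA are omitted; everything here is illustrative, not the repository's code):

import numpy as np
import torch
import torch.nn as nn

# Toy setup; model, data and hyperparameters are illustrative.
model = nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
nbs, batch_size, nw = 64, 16, 100  # nominal batch, actual batch, warmup iters
last_opt_step = -1

for ni in range(300):  # ni = integrated batch counter
    x, y = torch.randn(batch_size, 10), torch.randn(batch_size, 1)
    if ni <= nw:  # ramp accumulate from 1 to nbs / batch_size during warmup
        accumulate = max(1, np.interp(ni, [0, nw], [1, nbs / batch_size]).round())
    loss = nn.functional.mse_loss(model(x), y)
    loss.backward()  # gradients accumulate across batches until the step below
    if ni - last_opt_step >= accumulate:  # step only every `accumulate` batches
        optimizer.step()
        optimizer.zero_grad()
        last_opt_step = ni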
