From 53aae7fd85093c92ed94e544cc7d6a6c89ee106d Mon Sep 17 00:00:00 2001
From: Glenn Jocher
Date: Mon, 28 Jun 2021 12:04:42 +0200
Subject: [PATCH] Update train.py

---
 train.py | 37 ++++++++-----------------------------
 1 file changed, 8 insertions(+), 29 deletions(-)

diff --git a/train.py b/train.py
index 58bbc527928d..01429c0e9b39 100644
--- a/train.py
+++ b/train.py
@@ -6,7 +6,6 @@
 
 import argparse
 import logging
-import math
 import os
 import random
 import sys
@@ -16,6 +15,7 @@
 from pathlib import Path
 from threading import Thread
 
+import math
 import numpy as np
 import torch.distributed as dist
 import torch.nn as nn
@@ -53,30 +53,6 @@
 WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))
 
 
-def get_step_condition(nw, accumulate):
-    "Supports gradient accumulation and warmup for `nw` iterations"
-    nw, accumulate = int(nw), int(accumulate)
-
-    def acc(ni):
-        "Ramp up number of accumulated grads from 1 to `accumulate`"
-        return int(1 + (accumulate - 1) / nw * ni)
-
-    warmup_steps, ni = [], nw - 1
-    while ni >= 0:
-        warmup_steps.append(ni)
-        ni -= acc(ni)
-    warmup_steps = set(warmup_steps)
-
-    def step_condition(ni):
-        "Return whether optimizer.step() is to be called in iteration `ni`"
-        nonlocal nw, accumulate, warmup_steps
-        i = ni - nw
-        if (i >= 0) and ((i + 1) % accumulate == 0): return True
-        return ni in warmup_steps
-
-    return step_condition
-
-
 def train(hyp,  # path/to/hyp.yaml or hyp dictionary
           opt,
           device,
@@ -165,7 +141,7 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
 
     # Optimizer
     nbs = 64  # nominal batch size
-    accumulate = max(int(round(nbs / batch_size)), 1)  # accumulate loss before optimizing
+    accumulate = max(round(nbs / batch_size), 1)  # accumulate loss before optimizing
     hyp['weight_decay'] *= batch_size * accumulate / nbs  # scale weight_decay
     logger.info(f"Scaled weight_decay = {hyp['weight_decay']}")
 
@@ -292,8 +268,9 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
 
     # Start training
     t0 = time.time()
-    nw = round(hyp['warmup_epochs'] * nb)  # number of warmup iterations
-    do_step = get_step_condition(nw, accumulate)
+    nw = max(round(hyp['warmup_epochs'] * nb), 1000)  # number of warmup iterations, max(3 epochs, 1k iterations)
+    # nw = min(nw, (epochs - start_epoch) / 2 * nb)  # limit warmup to < 1/2 of training
+    last_opt_step = -1
     maps = np.zeros(nc)  # mAP per class
     results = (0, 0, 0, 0, 0, 0, 0)  # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls)
     scheduler.last_epoch = start_epoch - 1  # do not move
@@ -340,6 +317,7 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
             if ni <= nw:
                 xi = [0, nw]  # x interp
                 # model.gr = np.interp(ni, xi, [0.0, 1.0])  # iou loss ratio (obj_loss = 1.0 or iou)
+                accumulate = max(1, np.interp(ni, xi, [1, nbs / batch_size]).round())
                 for j, x in enumerate(optimizer.param_groups):
                     # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
                     x['lr'] = np.interp(ni, xi, [hyp['warmup_bias_lr'] if j == 2 else 0.0, x['initial_lr'] * lf(epoch)])
@@ -367,12 +345,13 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
             scaler.scale(loss).backward()
 
             # Optimize
-            if do_step(ni):
+            if ni - last_opt_step >= accumulate:
                 scaler.step(optimizer)  # optimizer.step
                 scaler.update()
                 optimizer.zero_grad()
                 if ema:
                     ema.update(model)
+                last_opt_step = ni
 
             # Print
             if RANK in [-1, 0]:
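Reviewer note: below is a minimal, self-contained sketch (not part of the patch) of the warmup/accumulation schedule this change restores: `accumulate` ramps from 1 up to `nbs / batch_size` via `np.interp` over the first `nw` iterations, and the optimizer steps only once every `accumulate` batches, tracked with `last_opt_step`. The constants `batch_size=16`, `nb=100`, and `warmup_epochs=3.0` are illustrative assumptions, not values taken from the repository.

```python
import numpy as np

nbs = 64             # nominal batch size (from the patch)
batch_size = 16      # assumed actual batch size
nb = 100             # assumed number of batches per epoch
warmup_epochs = 3.0  # assumed hyp['warmup_epochs']

nw = max(round(warmup_epochs * nb), 1000)  # warmup iterations, at least 1k (as in the patch)
last_opt_step = -1
steps = []  # iterations at which the optimizer would step

for ni in range(1500):  # ni = integrated batch counter across epochs
    if ni <= nw:
        # ramp gradient accumulation from 1 to nbs / batch_size during warmup
        accumulate = max(1, np.interp(ni, [0, nw], [1, nbs / batch_size]).round())
    if ni - last_opt_step >= accumulate:
        # scaler.step(optimizer), scaler.update(), optimizer.zero_grad() would run here
        last_opt_step = ni
        steps.append(ni)

print(f"{len(steps)} optimizer steps in 1500 batches (nw={nw})")
```

Compared with the removed `get_step_condition`, this keeps the same intent (fewer, gradually spaced optimizer steps during warmup, then one step per `accumulate` batches) without precomputing a set of warmup step indices.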