Warmup schedulers in References #4411

Merged
merged 9 commits on Sep 17, 2021
23 changes: 22 additions & 1 deletion references/classification/train.py
@@ -208,7 +208,25 @@ def main(args):
opt_level=args.apex_opt_level
)

lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)
args.lr_scheduler = args.lr_scheduler.lower()
if args.lr_scheduler == 'steplr':
main_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)
elif args.lr_scheduler == 'cosineannealinglr':
main_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
T_max=args.epochs - args.lr_warmup_epochs)
else:
raise RuntimeError("Invalid lr scheduler '{}'. Only StepLR and CosineAnnealingLR "
"are supported.".format(args.lr_scheduler))

if args.lr_warmup_epochs > 0:
lr_scheduler = torch.optim.lr_scheduler.SequentialLR(
optimizer,
schedulers=[torch.optim.lr_scheduler.ConstantLR(optimizer, factor=args.lr_warmup_decay,
total_iters=args.lr_warmup_epochs), main_lr_scheduler],
milestones=[args.lr_warmup_epochs]
)
else:
lr_scheduler = main_lr_scheduler

model_without_ddp = model
if args.distributed:
@@ -287,6 +305,9 @@ def get_args_parser(add_help=True):
dest='label_smoothing')
parser.add_argument('--mixup-alpha', default=0.0, type=float, help='mixup alpha (default: 0.0)')
parser.add_argument('--cutmix-alpha', default=0.0, type=float, help='cutmix alpha (default: 0.0)')
parser.add_argument('--lr-scheduler', default="steplr", help='the lr scheduler (default: steplr)')
parser.add_argument('--lr-warmup-epochs', default=0, type=int, help='the number of epochs to warmup (default: 0)')
parser.add_argument('--lr-warmup-decay', default=0.01, type=float, help='the decay for lr during warmup (default: 0.01)')
parser.add_argument('--lr-step-size', default=30, type=int, help='decrease lr every step-size epochs')
parser.add_argument('--lr-gamma', default=0.1, type=float, help='decrease lr by a factor of lr-gamma')
parser.add_argument('--print-freq', default=10, type=int, help='print frequency')
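For context, the sketch below shows the scheduler composition this hunk introduces, in isolation from the training script. The model, optimizer and hyper-parameter values are placeholders, and it assumes a PyTorch build that already ships ConstantLR and SequentialLR:

```python
# Minimal sketch of the classification schedule: a constant, reduced lr for the
# warmup epochs, then the main scheduler; all values below are illustrative.
import torch

model = torch.nn.Linear(10, 10)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

epochs, lr_warmup_epochs, lr_warmup_decay = 90, 5, 0.01

main_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=epochs - lr_warmup_epochs)

# ConstantLR multiplies the base lr by `factor` for `total_iters` steps, then
# restores it; SequentialLR switches to the main scheduler at the milestone.
warmup_lr_scheduler = torch.optim.lr_scheduler.ConstantLR(
    optimizer, factor=lr_warmup_decay, total_iters=lr_warmup_epochs)
lr_scheduler = torch.optim.lr_scheduler.SequentialLR(
    optimizer, schedulers=[warmup_lr_scheduler, main_lr_scheduler],
    milestones=[lr_warmup_epochs])

for epoch in range(epochs):
    # train_one_epoch(...) would run here; the scheduler is stepped per epoch.
    lr_scheduler.step()
```

Note that `T_max` is `epochs - lr_warmup_epochs` because the chained CosineAnnealingLR only starts counting once SequentialLR hands over control after the warmup milestone.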
3 changes: 2 additions & 1 deletion references/detection/engine.py
@@ -21,7 +21,8 @@ def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq):
warmup_factor = 1. / 1000
warmup_iters = min(1000, len(data_loader) - 1)

lr_scheduler = utils.warmup_lr_scheduler(optimizer, warmup_iters, warmup_factor)
lr_scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=warmup_factor,
total_iters=warmup_iters)

for images, targets in metric_logger.log_every(data_loader, print_freq, header):
images = list(image.to(device) for image in images)
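As a quick illustration of the replacement, the snippet below steps a LinearLR with the same `start_factor` and prints the resulting ramp; the optimizer, base lr and warmup length are made-up stand-ins, not the values used by the reference script:

```python
# LinearLR starts at lr * start_factor and ramps linearly up to lr over
# total_iters steps, matching the hand-rolled warmup that was removed.
import torch

optimizer = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.02)
warmup_factor = 1.0 / 1000
warmup_iters = 5  # stands in for min(1000, len(data_loader) - 1)

lr_scheduler = torch.optim.lr_scheduler.LinearLR(
    optimizer, start_factor=warmup_factor, total_iters=warmup_iters)

for it in range(warmup_iters + 2):
    print(it, optimizer.param_groups[0]["lr"])  # 2e-5 ... 0.02, then stays at 0.02
    # optimizer.step() on a real batch would go here
    lr_scheduler.step()
```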
11 changes: 0 additions & 11 deletions references/detection/utils.py
@@ -207,17 +207,6 @@ def collate_fn(batch):
return tuple(zip(*batch))


def warmup_lr_scheduler(optimizer, warmup_iters, warmup_factor):

def f(x):
if x >= warmup_iters:
return 1
alpha = float(x) / warmup_iters
return warmup_factor * (1 - alpha) + alpha

return torch.optim.lr_scheduler.LambdaLR(optimizer, f)


def mkdir(path):
try:
os.makedirs(path)
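The deleted helper and the new `LinearLR` call in `engine.py` should trace the same warmup curve, since `warmup_factor * (1 - alpha) + alpha` is exactly a linear interpolation from `warmup_factor` to 1. A small numerical check (with made-up lr, factor and iteration count) is below:

```python
# Compare the removed LambdaLR-based warmup with LinearLR; the values are
# arbitrary, chosen only to keep the printout short.
import torch

def make_optimizer():
    return torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=1.0)

warmup_iters, warmup_factor = 4, 0.001

opt_old = make_optimizer()
def f(x):
    if x >= warmup_iters:
        return 1
    alpha = float(x) / warmup_iters
    return warmup_factor * (1 - alpha) + alpha
old = torch.optim.lr_scheduler.LambdaLR(opt_old, f)

opt_new = make_optimizer()
new = torch.optim.lr_scheduler.LinearLR(
    opt_new, start_factor=warmup_factor, total_iters=warmup_iters)

for _ in range(warmup_iters + 1):
    # Both columns should match step for step.
    print(opt_old.param_groups[0]["lr"], opt_new.param_groups[0]["lr"])
    old.step()
    new.step()
```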
28 changes: 26 additions & 2 deletions references/segmentation/train.py
@@ -133,9 +133,30 @@ def main(args):
params_to_optimize,
lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)

lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
iters_per_epoch = len(data_loader)
main_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
optimizer,
lambda x: (1 - x / (len(data_loader) * args.epochs)) ** 0.9)
lambda x: (1 - x / (iters_per_epoch * (args.epochs - args.lr_warmup_epochs))) ** 0.9)
Member: Created an issue to track if we can now use a stock PyTorch scheduler for this: #4438


if args.lr_warmup_epochs > 0:
warmup_iters = iters_per_epoch * args.lr_warmup_epochs
args.lr_warmup_method = args.lr_warmup_method.lower()
if args.lr_warmup_method == 'linear':
warmup_lr_scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=args.lr_warmup_decay,
total_iters=warmup_iters)
elif args.lr_warmup_method == 'constant':
warmup_lr_scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer, factor=args.lr_warmup_decay,
total_iters=warmup_iters)
else:
raise RuntimeError("Invalid warmup lr method '{}'. Only linear and constant "
"are supported.".format(args.lr_warmup_method))
lr_scheduler = torch.optim.lr_scheduler.SequentialLR(
optimizer,
schedulers=[warmup_lr_scheduler, main_lr_scheduler],
milestones=[warmup_iters]
)
else:
lr_scheduler = main_lr_scheduler

if args.resume:
checkpoint = torch.load(args.resume, map_location='cpu')
@@ -197,6 +218,9 @@ def get_args_parser(add_help=True):
parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
metavar='W', help='weight decay (default: 1e-4)',
dest='weight_decay')
parser.add_argument('--lr-warmup-epochs', default=0, type=int, help='the number of epochs to warmup (default: 0)')
parser.add_argument('--lr-warmup-method', default="linear", type=str, help='the warmup method (default: linear)')
parser.add_argument('--lr-warmup-decay', default=0.01, type=float, help='the decay for lr during warmup (default: 0.01)')
parser.add_argument('--print-freq', default=10, type=int, help='print frequency')
parser.add_argument('--output-dir', default='.', help='path where to save')
parser.add_argument('--resume', default='', help='resume from checkpoint')
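Putting the segmentation pieces together, here is a self-contained sketch of the chained warmup plus polynomial decay; the iteration counts and hyper-parameters are illustrative stand-ins for `len(data_loader)` and the parsed arguments:

```python
# Warmup (linear or constant) chained in front of the polynomial-decay LambdaLR,
# stepped once per iteration as in the segmentation reference script.
import torch

optimizer = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.01)

iters_per_epoch, epochs = 100, 30            # stand-ins for len(data_loader), args.epochs
lr_warmup_epochs, lr_warmup_decay = 1, 0.01
lr_warmup_method = "linear"                  # or "constant"

main_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer,
    lambda x: (1 - x / (iters_per_epoch * (epochs - lr_warmup_epochs))) ** 0.9)

warmup_iters = iters_per_epoch * lr_warmup_epochs
if lr_warmup_method == "linear":
    warmup_lr_scheduler = torch.optim.lr_scheduler.LinearLR(
        optimizer, start_factor=lr_warmup_decay, total_iters=warmup_iters)
else:
    warmup_lr_scheduler = torch.optim.lr_scheduler.ConstantLR(
        optimizer, factor=lr_warmup_decay, total_iters=warmup_iters)

lr_scheduler = torch.optim.lr_scheduler.SequentialLR(
    optimizer, schedulers=[warmup_lr_scheduler, main_lr_scheduler],
    milestones=[warmup_iters])

for _ in range(iters_per_epoch * epochs):
    lr_scheduler.step()
```

The lambda divides by `iters_per_epoch * (epochs - lr_warmup_epochs)` because the chained scheduler starts counting from zero only after the warmup milestone; the review comment above (#4438) tracks whether a stock PyTorch scheduler could replace this hand-written polynomial decay.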
2 changes: 1 addition & 1 deletion references/segmentation/transforms.py
@@ -37,7 +37,7 @@ def __init__(self, min_size, max_size=None):
def __call__(self, image, target):
size = random.randint(self.min_size, self.max_size)
image = F.resize(image, size)
target = F.resize(target, size, interpolation=Image.NEAREST)
target = F.resize(target, size, interpolation=T.InterpolationMode.NEAREST)
return image, target


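For reference, a small sketch of the updated call; it assumes `T` refers to `torchvision.transforms` and `F` to `torchvision.transforms.functional`, matching how the new line is written, and uses dummy tensors in place of real data:

```python
# Resize an image/mask pair; NEAREST keeps segmentation labels intact instead
# of blending class ids, and InterpolationMode replaces the PIL constant.
import torch
from torchvision import transforms as T
from torchvision.transforms import functional as F

image = torch.rand(3, 64, 64)                         # dummy image tensor
target = torch.randint(0, 21, (1, 64, 64)).float()    # dummy segmentation mask

size = 48
image = F.resize(image, size)
target = F.resize(target, size, interpolation=T.InterpolationMode.NEAREST)
```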
47 changes: 0 additions & 47 deletions references/video_classification/scheduler.py

This file was deleted.

35 changes: 27 additions & 8 deletions references/video_classification/train.py
@@ -12,8 +12,6 @@
import presets
import utils

from scheduler import WarmupMultiStepLR

try:
from apex import amp
except ImportError:
@@ -202,11 +200,30 @@ def main(args):

# convert scheduler to be per iteration, not per epoch, for warmup that lasts
# between different epochs
warmup_iters = args.lr_warmup_epochs * len(data_loader)
lr_milestones = [len(data_loader) * m for m in args.lr_milestones]
lr_scheduler = WarmupMultiStepLR(
optimizer, milestones=lr_milestones, gamma=args.lr_gamma,
warmup_iters=warmup_iters, warmup_factor=1e-5)
iters_per_epoch = len(data_loader)
lr_milestones = [iters_per_epoch * (m - args.lr_warmup_epochs) for m in args.lr_milestones]
main_lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=lr_milestones, gamma=args.lr_gamma)

if args.lr_warmup_epochs > 0:
warmup_iters = iters_per_epoch * args.lr_warmup_epochs
args.lr_warmup_method = args.lr_warmup_method.lower()
if args.lr_warmup_method == 'linear':
warmup_lr_scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=args.lr_warmup_decay,
total_iters=warmup_iters)
elif args.lr_warmup_method == 'constant':
warmup_lr_scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer, factor=args.lr_warmup_decay,
total_iters=warmup_iters)
else:
raise RuntimeError("Invalid warmup lr method '{}'. Only linear and constant "
"are supported.".format(args.lr_warmup_method))

lr_scheduler = torch.optim.lr_scheduler.SequentialLR(
optimizer,
schedulers=[warmup_lr_scheduler, main_lr_scheduler],
milestones=[warmup_iters]
)
else:
lr_scheduler = main_lr_scheduler

model_without_ddp = model
if args.distributed:
@@ -277,7 +294,9 @@ def parse_args():
dest='weight_decay')
parser.add_argument('--lr-milestones', nargs='+', default=[20, 30, 40], type=int, help='decrease lr on milestones')
parser.add_argument('--lr-gamma', default=0.1, type=float, help='decrease lr by a factor of lr-gamma')
parser.add_argument('--lr-warmup-epochs', default=10, type=int, help='number of warmup epochs')
parser.add_argument('--lr-warmup-epochs', default=10, type=int, help='the number of epochs to warmup (default: 10)')
parser.add_argument('--lr-warmup-method', default="linear", type=str, help='the warmup method (default: linear)')
parser.add_argument('--lr-warmup-decay', default=0.001, type=float, help='the decay for lr during warmup (default: 0.001)')
parser.add_argument('--print-freq', default=10, type=int, help='print frequency')
parser.add_argument('--output-dir', default='.', help='path where to save')
parser.add_argument('--resume', default='', help='resume from checkpoint')
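Finally, a sketch of why the milestones are shifted by the warmup epochs: once SequentialLR passes the warmup milestone, the chained MultiStepLR counts its own steps from zero, so each epoch milestone is first reduced by `lr_warmup_epochs` and then converted to iterations. The sizes below are illustrative, not the script's actual defaults for every argument:

```python
# Per-iteration warmup followed by MultiStepLR, with milestones expressed in
# iterations relative to the end of warmup; all sizes below are illustrative.
import torch

optimizer = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.01)

iters_per_epoch = 50                       # stand-in for len(data_loader)
epochs, lr_warmup_epochs = 45, 10
lr_milestones, lr_gamma = [20, 30, 40], 0.1
lr_warmup_decay = 0.001

milestones = [iters_per_epoch * (m - lr_warmup_epochs) for m in lr_milestones]
main_lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer, milestones=milestones, gamma=lr_gamma)

warmup_iters = iters_per_epoch * lr_warmup_epochs
warmup_lr_scheduler = torch.optim.lr_scheduler.LinearLR(
    optimizer, start_factor=lr_warmup_decay, total_iters=warmup_iters)

lr_scheduler = torch.optim.lr_scheduler.SequentialLR(
    optimizer, schedulers=[warmup_lr_scheduler, main_lr_scheduler],
    milestones=[warmup_iters])

# Stepped once per iteration, so the lr drops land at epochs 20, 30 and 40 overall.
for _ in range(iters_per_epoch * epochs):
    lr_scheduler.step()
```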