-
Notifications
You must be signed in to change notification settings - Fork 0
/
util.py
20 lines (18 loc) · 819 Bytes
/
util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
from functools import reduce
import torch.nn.functional as F
import torch
import torch.nn as nn
def Optimizer(args, model):
    """Build a torch optimizer over ``model.parameters()`` selected by ``args.optimizer``.

    Supported values of ``args.optimizer``:

    * ``'Adam'`` -- ``torch.optim.Adam`` with ``lr=args.lr`` and
      ``weight_decay=args.weight_decay``.
    * ``'SGD'`` -- ``torch.optim.SGD`` with ``lr=args.lr``, a fixed
      ``momentum=0.9`` and a fixed ``weight_decay=5e-4``.
      NOTE: this branch ignores ``args.weight_decay``. The hard-coded value is
      kept for backward compatibility.

    Args:
        args: Namespace-like object exposing ``optimizer``, ``lr`` and (for Adam)
            ``weight_decay`` attributes.
        model: Module whose ``parameters()`` will be optimized.

    Returns:
        The constructed optimizer instance.

    Raises:
        ValueError: If ``args.optimizer`` is not one of the supported names.
            (Previously this surfaced as an ``UnboundLocalError``.)
    """
    name = args.optimizer
    if name == 'Adam':
        return torch.optim.Adam(params=model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    if name == 'SGD':
        return torch.optim.SGD(params=model.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
    raise ValueError(f"Unsupported optimizer {name!r}; expected 'Adam' or 'SGD'")
def Scheduler(args, optimizer):
    """Build a learning-rate scheduler for ``optimizer`` selected by ``args.scheduler``.

    Supported values of ``args.scheduler``:

    * ``'Reduce'`` -- ``ReduceLROnPlateau`` in ``mode='min'`` with
      ``factor=args.lr_factor`` and ``patience=args.patience``. Per the PyTorch
      API, its ``step()`` must be called with the monitored metric value.
    * ``'Step'`` -- ``StepLR`` that multiplies the learning rate by 0.9 every
      2 calls to ``step()`` (both values are fixed here).

    Args:
        args: Namespace-like object exposing ``scheduler`` and, for ``'Reduce'``,
            ``lr_factor`` and ``patience`` attributes.
        optimizer: The ``torch.optim`` optimizer whose learning rate is scheduled.

    Returns:
        The constructed scheduler instance.

    Raises:
        ValueError: If ``args.scheduler`` is not one of the supported names.
            (Previously this surfaced as an ``UnboundLocalError``.)
    """
    name = args.scheduler
    if name == 'Reduce':
        return torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='min', factor=args.lr_factor, patience=args.patience)
    if name == 'Step':
        return torch.optim.lr_scheduler.StepLR(
            optimizer, step_size=2, gamma=0.9)
    raise ValueError(f"Unsupported scheduler {name!r}; expected 'Reduce' or 'Step'")