From ef8de4c80bd71436ab2640450b5f90245fab2882 Mon Sep 17 00:00:00 2001
From: bilzard <36561962+bilzard@users.noreply.github.com>
Date: Sun, 2 Jan 2022 11:28:13 +0900
Subject: [PATCH] enable AdamW optimizer

---
 train.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/train.py b/train.py
index e2cd5ec85c09..304c001b6547 100644
--- a/train.py
+++ b/train.py
@@ -22,7 +22,7 @@
 import yaml
 from torch.cuda import amp
 from torch.nn.parallel import DistributedDataParallel as DDP
-from torch.optim import SGD, Adam, lr_scheduler
+from torch.optim import SGD, Adam, AdamW, lr_scheduler
 from tqdm import tqdm
 
 FILE = Path(__file__).resolve()
@@ -155,8 +155,10 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
         elif hasattr(v, 'weight') and isinstance(v.weight, nn.Parameter):  # weight (with decay)
             g1.append(v.weight)
 
-    if opt.adam:
+    if opt.optimizer == 'Adam':
         optimizer = Adam(g0, lr=hyp['lr0'], betas=(hyp['momentum'], 0.999))  # adjust beta1 to momentum
+    elif opt.optimizer == 'AdamW':
+        optimizer = AdamW(g0, lr=hyp['lr0'], betas=(hyp['momentum'], 0.999))  # adjust beta1 to momentum
     else:
         optimizer = SGD(g0, lr=hyp['lr0'], momentum=hyp['momentum'], nesterov=True)
 
@@ -460,7 +462,7 @@ def parse_opt(known=False):
     parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
     parser.add_argument('--multi-scale', action='store_true', help='vary img-size +/- 50%%')
     parser.add_argument('--single-cls', action='store_true', help='train multi-class data as single-class')
-    parser.add_argument('--adam', action='store_true', help='use torch.optim.Adam() optimizer')
+    parser.add_argument('--optimizer', type=str, choices=['SGD', 'Adam', 'AdamW'], default='SGD', help='optimizer')
     parser.add_argument('--sync-bn', action='store_true', help='use SyncBatchNorm, only available in DDP mode')
     parser.add_argument('--workers', type=int, default=8, help='max dataloader workers (per RANK in DDP mode)')
     parser.add_argument('--project', default=ROOT / 'runs/train', help='save to project/name')
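
Usage note: after this patch, the optimizer is selected with the new --optimizer
flag rather than the removed --adam boolean, and SGD remains the default. A
minimal invocation sketch, assuming the surrounding train.py flags (--data,
--weights, --img) keep their usual YOLOv5 defaults:

    # train with AdamW (decoupled weight decay) instead of the default SGD
    python train.py --data coco128.yaml --weights yolov5s.pt --img 640 --optimizer AdamW

Callers that previously passed --adam would now pass --optimizer Adam instead;
an unrecognized value is rejected by argparse thanks to the choices list.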