From a1c8406af3eac3e20d4dd5d327fd6cbd4fbb9752 Mon Sep 17 00:00:00 2001
From: Glenn Jocher
Date: Mon, 13 Jul 2020 20:19:10 -0700
Subject: [PATCH] EMA and non_blocking=True

---
 test.py  | 2 +-
 train.py | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/test.py b/test.py
index f819bae1c06e..faad3477fd77 100644
--- a/test.py
+++ b/test.py
@@ -69,7 +69,7 @@ def test(data,
     loss = torch.zeros(3, device=device)
     jdict, stats, ap, ap_class = [], [], [], []
     for batch_i, (img, targets, paths, shapes) in enumerate(tqdm(dataloader, desc=s)):
-        img = img.to(device)
+        img = img.to(device, non_blocking=True)
         img = img.half() if half else img.float()  # uint8 to fp16/32
         img /= 255.0  # 0 - 255 to 0.0 - 1.0
         targets = targets.to(device)
diff --git a/train.py b/train.py
index 85a161155a12..3526c6cc315c 100644
--- a/train.py
+++ b/train.py
@@ -193,7 +193,7 @@ def train(hyp):
         check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)
 
     # Exponential moving average
-    ema = torch_utils.ModelEMA(model, updates=start_epoch * nb / accumulate)
+    ema = torch_utils.ModelEMA(model)
 
     # Start training
     t0 = time.time()
@@ -223,7 +223,7 @@ def train(hyp):
         pbar = tqdm(enumerate(dataloader), total=nb)  # progress bar
         for i, (imgs, targets, paths, _) in pbar:  # batch -------------------------------------------------------------
             ni = i + nb * epoch  # number integrated batches (since train start)
-            imgs = imgs.to(device).float() / 255.0  # uint8 to float32, 0 - 255 to 0.0 - 1.0
+            imgs = imgs.to(device, non_blocking=True).float() / 255.0  # uint8 to float32, 0 - 255 to 0.0 - 1.0
 
             # Warmup
             if ni <= nw:
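
Note (not part of the patch): non_blocking=True only turns the host-to-device copy into an asynchronous one when the source tensor lives in pinned (page-locked) host memory, which a DataLoader provides when built with pin_memory=True; otherwise the call silently behaves like a regular blocking copy. A minimal standalone sketch of that pairing, using a toy dataset in place of the repo's own dataloader (illustrative only, not the YOLOv5 code):

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Toy uint8 "images" standing in for the YOLOv5 dataloader output
    dataset = TensorDataset(torch.randint(0, 256, (64, 3, 32, 32), dtype=torch.uint8))
    loader = DataLoader(dataset, batch_size=16, pin_memory=True)  # pinned host buffers

    for (imgs,) in loader:
        # Asynchronous copy when the batch is pinned; falls back to a blocking copy otherwise
        imgs = imgs.to(device, non_blocking=True).float() / 255.0  # uint8 -> float32, 0-1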