Skip to content

Commit

Permalink
EMA and non_blocking=True
Browse files Browse the repository at this point in the history
  • Loading branch information
glenn-jocher committed Jul 14, 2020
1 parent 2377e5f commit a1c8406
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion test.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def test(data,
loss = torch.zeros(3, device=device)
jdict, stats, ap, ap_class = [], [], [], []
for batch_i, (img, targets, paths, shapes) in enumerate(tqdm(dataloader, desc=s)):
img = img.to(device)
img = img.to(device, non_blocking=True)
img = img.half() if half else img.float() # uint8 to fp16/32
img /= 255.0 # 0 - 255 to 0.0 - 1.0
targets = targets.to(device)
Expand Down
4 changes: 2 additions & 2 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def train(hyp):
check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)

# Exponential moving average
ema = torch_utils.ModelEMA(model, updates=start_epoch * nb / accumulate)
ema = torch_utils.ModelEMA(model)

# Start training
t0 = time.time()
Expand Down Expand Up @@ -223,7 +223,7 @@ def train(hyp):
pbar = tqdm(enumerate(dataloader), total=nb) # progress bar
for i, (imgs, targets, paths, _) in pbar: # batch -------------------------------------------------------------
ni = i + nb * epoch # number integrated batches (since train start)
imgs = imgs.to(device).float() / 255.0 # uint8 to float32, 0 - 255 to 0.0 - 1.0
imgs = imgs.to(device, non_blocking=True).float() / 255.0 # uint8 to float32, 0 - 255 to 0.0 - 1.0

# Warmup
if ni <= nw:
Expand Down

0 comments on commit a1c8406

Please sign in to comment.