Fix redundant outputs via Logging in DDP training (ultralytics#500)
* Change print to logging

* Clean function set_logging

* Add line spacing

* Change leftover prints to log

* Fix scanning labels output

* Fix rank naming

* Change leftover print to logging

* Reorganized DDP variables

* Fix type error

* Make quotes consistent

* Fix spelling

* Clean function call

* Add line spacing

* Update datasets.py

* Update train.py

Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
NanoCode012 and glenn-jocher committed Aug 11, 2020
1 parent 38bceec commit 15cf413
Showing 5 changed files with 56 additions and 39 deletions.
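
The common thread across the five files: bare print calls become logging calls, and the new set_logging helper picks the log level per DDP rank, so only the primary process (rank -1 or 0) emits INFO messages and duplicate output from worker processes disappears. A condensed sketch of the pattern, pieced together from the hunks below (the sample message is illustrative):

import logging
import os

def set_logging(rank=-1):
    # Added to utils/general.py: INFO on the primary process, WARN on DDP workers
    logging.basicConfig(format="%(message)s",
                        level=logging.INFO if rank in [-1, 0] else logging.WARN)

# As in train.py: the global rank comes from the env var set by torch.distributed.launch
global_rank = int(os.environ['RANK']) if 'RANK' in os.environ else -1
set_logging(global_rank)

# Each module then swaps print(...) for a module-level logger (yolo.py, train.py, torch_utils.py)
logger = logging.getLogger(__name__)
logger.info('Model Summary: ...')  # printed once, by rank -1/0 only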
6 changes: 4 additions & 2 deletions models/yolo.py
@@ -1,5 +1,6 @@
import argparse
import math
import logging
from copy import deepcopy
from pathlib import Path

@@ -12,6 +13,7 @@
from utils.torch_utils import (
time_synchronized, fuse_conv_and_bn, model_info, scale_img, initialize_weights, select_device)

logger = logging.getLogger(__name__)

class Detect(nn.Module):
def __init__(self, nc=80, anchors=(), ch=()): # detection layer
@@ -169,7 +171,7 @@ def info(self): # print model information


def parse_model(d, ch): # model_dict, input_channels(3)
print('\n%3s%18s%3s%10s %-40s%-30s' % ('', 'from', 'n', 'params', 'module', 'arguments'))
logger.info('\n%3s%18s%3s%10s %-40s%-30s' % ('', 'from', 'n', 'params', 'module', 'arguments'))
anchors, nc, gd, gw = d['anchors'], d['nc'], d['depth_multiple'], d['width_multiple']
na = (len(anchors[0]) // 2) if isinstance(anchors, list) else anchors # number of anchors
no = na * (nc + 5) # number of outputs = anchors * (classes + 5)
@@ -224,7 +226,7 @@ def parse_model(d, ch): # model_dict, input_channels(3)
t = str(m)[8:-2].replace('__main__.', '') # module type
np = sum([x.numel() for x in m_.parameters()]) # number params
m_.i, m_.f, m_.type, m_.np = i, f, t, np # attach index, 'from' index, type, number params
print('%3s%18s%3s%10.0f %-40s%-30s' % (i, f, n, np, t, args)) # print
logger.info('%3s%18s%3s%10.0f %-40s%-30s' % (i, f, n, np, t, args)) # print
save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1) # append to savelist
layers.append(m_)
ch.append(c2)
50 changes: 26 additions & 24 deletions train.py
@@ -3,6 +3,7 @@
import os
import random
import time
import logging
from pathlib import Path

import numpy as np
@@ -23,13 +24,14 @@
from utils.general import (
torch_distributed_zero_first, labels_to_class_weights, plot_labels, check_anchors, labels_to_image_weights,
compute_loss, plot_images, fitness, strip_optimizer, plot_results, get_latest_run, check_dataset, check_file,
check_git_status, check_img_size, increment_dir, print_mutation, plot_evolution)
check_git_status, check_img_size, increment_dir, print_mutation, plot_evolution, set_logging)
from utils.google_utils import attempt_download
from utils.torch_utils import init_seeds, ModelEMA, select_device, intersect_dicts

logger = logging.getLogger(__name__)

def train(hyp, opt, device, tb_writer=None):
print(f'Hyperparameters {hyp}')
logger.info(f'Hyperparameters {hyp}')
log_dir = Path(tb_writer.log_dir) if tb_writer else Path(opt.logdir) / 'evolve' # logging directory
wdir = str(log_dir / 'weights') + os.sep # weights directory
os.makedirs(wdir, exist_ok=True)
@@ -69,7 +71,7 @@ def train(hyp, opt, device, tb_writer=None):
state_dict = ckpt['model'].float().state_dict() # to FP32
state_dict = intersect_dicts(state_dict, model.state_dict(), exclude=exclude) # intersect
model.load_state_dict(state_dict, strict=False) # load
print('Transferred %g/%g items from %s' % (len(state_dict), len(model.state_dict()), weights)) # report
logging.info('Transferred %g/%g items from %s' % (len(state_dict), len(model.state_dict()), weights)) # report
else:
model = Model(opt.cfg, ch=3, nc=nc).to(device) # create

@@ -103,7 +105,7 @@ def train(hyp, opt, device, tb_writer=None):

optimizer.add_param_group({'params': pg1, 'weight_decay': hyp['weight_decay']}) # add pg1 with weight_decay
optimizer.add_param_group({'params': pg2}) # add pg2 (biases)
print('Optimizer groups: %g .bias, %g conv.weight, %g other' % (len(pg2), len(pg1), len(pg0)))
logger.info('Optimizer groups: %g .bias, %g conv.weight, %g other' % (len(pg2), len(pg1), len(pg0)))
del pg0, pg1, pg2

# Scheduler https://arxiv.org/pdf/1812.01187.pdf
@@ -128,7 +130,7 @@
# Epochs
start_epoch = ckpt['epoch'] + 1
if epochs < start_epoch:
print('%s has been trained for %g epochs. Fine-tuning for %g additional epochs.' %
logger.info('%s has been trained for %g epochs. Fine-tuning for %g additional epochs.' %
(weights, ckpt['epoch'], epochs))
epochs += ckpt['epoch'] # finetune additional epochs

@@ -145,7 +147,7 @@
# SyncBatchNorm
if opt.sync_bn and cuda and rank != -1:
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)
print('Using SyncBatchNorm()')
logger.info('Using SyncBatchNorm()')

# Exponential moving average
ema = ModelEMA(model) if rank in [-1, 0] else None
@@ -156,7 +158,7 @@

# Trainloader
dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, opt, hyp=hyp, augment=True,
cache=opt.cache_images, rect=opt.rect, local_rank=rank,
cache=opt.cache_images, rect=opt.rect, rank=rank,
world_size=opt.world_size)
mlc = np.concatenate(dataset.labels, 0)[:, 0].max() # max label class
nb = len(dataloader) # number of batches
@@ -166,7 +168,7 @@
if rank in [-1, 0]:
# local_rank is set to -1. Because only the first process is expected to do evaluation.
testloader = create_dataloader(test_path, imgsz_test, total_batch_size, gs, opt, hyp=hyp, augment=False,
cache=opt.cache_images, rect=True, local_rank=-1, world_size=opt.world_size)[0]
cache=opt.cache_images, rect=True, rank=-1, world_size=opt.world_size)[0]

# Model parameters
hyp['cls'] *= nc / 80. # scale coco-tuned hyp['cls'] to current dataset
@@ -199,10 +201,9 @@ def train(hyp, opt, device, tb_writer=None):
results = (0, 0, 0, 0, 0, 0, 0) # 'P', 'R', 'mAP', 'F1', 'val GIoU', 'val Objectness', 'val Classification'
scheduler.last_epoch = start_epoch - 1 # do not move
scaler = amp.GradScaler(enabled=cuda)
if rank in [0, -1]:
print('Image sizes %g train, %g test' % (imgsz, imgsz_test))
print('Using %g dataloader workers' % dataloader.num_workers)
print('Starting training for %g epochs...' % epochs)
logger.info('Image sizes %g train, %g test' % (imgsz, imgsz_test))
logger.info('Using %g dataloader workers' % dataloader.num_workers)
logger.info('Starting training for %g epochs...' % epochs)
# torch.autograd.set_detect_anomaly(True)
for epoch in range(start_epoch, epochs): # epoch ------------------------------------------------------------------
model.train()
@@ -232,8 +233,8 @@ def train(hyp, opt, device, tb_writer=None):
if rank != -1:
dataloader.sampler.set_epoch(epoch)
pbar = enumerate(dataloader)
logging.info(('\n' + '%10s' * 8) % ('Epoch', 'gpu_mem', 'GIoU', 'obj', 'cls', 'total', 'targets', 'img_size'))
if rank in [-1, 0]:
print(('\n' + '%10s' * 8) % ('Epoch', 'gpu_mem', 'GIoU', 'obj', 'cls', 'total', 'targets', 'img_size'))
pbar = tqdm(pbar, total=nb) # progress bar
optimizer.zero_grad()
for i, (imgs, targets, paths, _) in pbar: # batch -------------------------------------------------------------
@@ -269,7 +270,7 @@ def train(hyp, opt, device, tb_writer=None):
if rank != -1:
loss *= opt.world_size # gradient averaged between devices in DDP mode
# if not torch.isfinite(loss):
# print('WARNING: non-finite loss, ending training ', loss_items)
# logger.info('WARNING: non-finite loss, ending training ', loss_items)
# return results

# Backward
@@ -369,7 +370,7 @@ def train(hyp, opt, device, tb_writer=None):
# Finish
if not opt.evolve:
plot_results(save_dir=log_dir) # save as results.png
print('%g epochs completed in %.3f hours.\n' % (epoch - start_epoch + 1, (time.time() - t0) / 3600))
logger.info('%g epochs completed in %.3f hours.\n' % (epoch - start_epoch + 1, (time.time() - t0) / 3600))

dist.destroy_process_group() if rank not in [-1, 0] else None
torch.cuda.empty_cache()
@@ -404,13 +405,19 @@ def train(hyp, opt, device, tb_writer=None):
parser.add_argument('--logdir', type=str, default='runs/', help='logging directory')
opt = parser.parse_args()

# Set DDP variables
opt.total_batch_size = opt.batch_size
opt.world_size = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1
opt.global_rank = int(os.environ['RANK']) if 'RANK' in os.environ else -1
set_logging(opt.global_rank)

# Resume
if opt.resume:
last = get_latest_run() if opt.resume == 'get_last' else opt.resume # resume from most recent run
if last and not opt.weights:
print(f'Resuming training from {last}')
logger.info(f'Resuming training from {last}')
opt.weights = last if opt.resume and not opt.weights else opt.weights
if opt.local_rank == -1 or ("RANK" in os.environ and os.environ["RANK"] == "0"):
if opt.global_rank in [-1,0]:
check_git_status()

opt.hyp = opt.hyp or ('data/hyp.finetune.yaml' if opt.weights else 'data/hyp.scratch.yaml')
@@ -419,30 +426,25 @@ def train(hyp, opt, device, tb_writer=None):

opt.img_size.extend([opt.img_size[-1]] * (2 - len(opt.img_size))) # extend to 2 sizes (train, test)
device = select_device(opt.device, batch_size=opt.batch_size)
opt.total_batch_size = opt.batch_size
opt.world_size = 1
opt.global_rank = -1

# DDP mode
if opt.local_rank != -1:
assert torch.cuda.device_count() > opt.local_rank
torch.cuda.set_device(opt.local_rank)
device = torch.device('cuda', opt.local_rank)
dist.init_process_group(backend='nccl', init_method='env://') # distributed backend
opt.world_size = dist.get_world_size()
opt.global_rank = dist.get_rank()
assert opt.batch_size % opt.world_size == 0, '--batch-size must be multiple of CUDA device count'
opt.batch_size = opt.total_batch_size // opt.world_size

print(opt)
logger.info(opt)
with open(opt.hyp) as f:
hyp = yaml.load(f, Loader=yaml.FullLoader) # load hyps

# Train
if not opt.evolve:
tb_writer = None
if opt.global_rank in [-1, 0]:
print('Start Tensorboard with "tensorboard --logdir %s", view at http://localhost:6006/' % opt.logdir)
logger.info('Start Tensorboard with "tensorboard --logdir %s", view at http://localhost:6006/' % opt.logdir)
tb_writer = SummaryWriter(log_dir=increment_dir(Path(opt.logdir) / 'exp', opt.name)) # runs/exp

train(hyp, opt, device, tb_writer)
22 changes: 13 additions & 9 deletions utils/datasets.py
@@ -47,21 +47,22 @@ def exif_size(img):


def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=False, cache=False, pad=0.0, rect=False,
local_rank=-1, world_size=1):
rank=-1, world_size=1):
# Make sure only the first process in DDP process the dataset first, and the following others can use the cache.
with torch_distributed_zero_first(local_rank):
with torch_distributed_zero_first(rank):
dataset = LoadImagesAndLabels(path, imgsz, batch_size,
augment=augment, # augment images
hyp=hyp, # augmentation hyperparameters
rect=rect, # rectangular training
cache_images=cache,
single_cls=opt.single_cls,
stride=int(stride),
pad=pad)
pad=pad,
rank=rank)

batch_size = min(batch_size, len(dataset))
nw = min([os.cpu_count() // world_size, batch_size if batch_size > 1 else 0, 8]) # number of workers
train_sampler = torch.utils.data.distributed.DistributedSampler(dataset) if local_rank != -1 else None
train_sampler = torch.utils.data.distributed.DistributedSampler(dataset) if rank != -1 else None
dataloader = torch.utils.data.DataLoader(dataset,
batch_size=batch_size,
num_workers=nw,
@@ -292,7 +293,7 @@ def __len__(self):

class LoadImagesAndLabels(Dataset): # for training/testing
def __init__(self, path, img_size=640, batch_size=16, augment=False, hyp=None, rect=False, image_weights=False,
cache_images=False, single_cls=False, stride=32, pad=0.0):
cache_images=False, single_cls=False, stride=32, pad=0.0, rank=-1):
try:
f = [] # image files
for p in path if isinstance(path, list) else [path]:
@@ -372,8 +373,10 @@ def __init__(self, path, img_size=640, batch_size=16, augment=False, hyp=None, r
# Cache labels
create_datasubset, extract_bounding_boxes, labels_loaded = False, False, False
nm, nf, ne, ns, nd = 0, 0, 0, 0, 0 # number missing, found, empty, datasubset, duplicate
pbar = tqdm(self.label_files)
for i, file in enumerate(pbar):
pbar = enumerate(self.label_files)
if rank in [-1, 0]:
pbar = tqdm(pbar)
for i, file in pbar:
l = self.labels[i] # label
if l is not None and l.shape[0]:
assert l.shape[1] == 5, '> 5 label columns: %s' % file
@@ -420,8 +423,9 @@ def __init__(self, path, img_size=640, batch_size=16, augment=False, hyp=None, r
ne += 1 # print('empty labels for image %s' % self.img_files[i]) # file empty
# os.system("rm '%s' '%s'" % (self.img_files[i], self.label_files[i])) # remove

pbar.desc = 'Scanning labels %s (%g found, %g missing, %g empty, %g duplicate, for %g images)' % (
cache_path, nf, nm, ne, nd, n)
if rank in [-1,0]:
pbar.desc = 'Scanning labels %s (%g found, %g missing, %g empty, %g duplicate, for %g images)' % (
cache_path, nf, nm, ne, nd, n)
if nf == 0:
s = 'WARNING: No labels found in %s. See %s' % (os.path.dirname(file) + os.sep, help_url)
print(s)
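
The comment in create_dataloader above ("Make sure only the first process in DDP process the dataset first...") relies on the torch_distributed_zero_first context manager, whose signature appears in the utils/general.py hunk below. Roughly, and only as a sketch of the existing helper rather than part of this diff, it behaves like:

from contextlib import contextmanager

import torch

@contextmanager
def torch_distributed_zero_first(local_rank: int):
    # Workers (rank > 0) block here until rank 0 has scanned and cached the labels
    if local_rank not in [-1, 0]:
        torch.distributed.barrier()
    yield
    # Rank 0 reaches the barrier last, releasing the workers to read from the cache
    if local_rank == 0:
        torch.distributed.barrier()

Renaming the parameter from local_rank to rank and threading it through create_dataloader into LoadImagesAndLabels is what lets the label-scanning tqdm bar above show only on rank -1/0.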
7 changes: 7 additions & 0 deletions utils/general.py
@@ -5,6 +5,7 @@
import shutil
import subprocess
import time
import logging
from contextlib import contextmanager
from copy import copy
from pathlib import Path
@@ -45,6 +46,12 @@ def torch_distributed_zero_first(local_rank: int):
torch.distributed.barrier()


def set_logging(rank=-1):
logging.basicConfig(
format="%(message)s",
level=logging.INFO if rank in [-1, 0] else logging.WARN)


def init_seeds(seed=0):
random.seed(seed)
np.random.seed(seed)
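
Because set_logging configures the root logger at WARN level on non-zero ranks, duplicate INFO output from DDP workers is dropped while warnings still reach the console on every process. An illustrative check, assuming it is run from the repository root (the messages are made up):

import logging

from utils.general import set_logging  # the helper added above

set_logging(rank=1)                          # pretend we are a DDP worker
log = logging.getLogger(__name__)
log.info('Scanning labels ...')              # suppressed: worker level is WARN
log.warning('WARNING: No labels found')      # still emitted on every rank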
10 changes: 6 additions & 4 deletions utils/torch_utils.py
@@ -1,6 +1,7 @@
import math
import os
import time
import logging
from copy import deepcopy

import torch
@@ -9,6 +10,7 @@
import torch.nn.functional as F
import torchvision.models as models

logger = logging.getLogger(__name__)

def init_seeds(seed=0):
torch.manual_seed(seed)
@@ -40,12 +42,12 @@ def select_device(device='', batch_size=None):
for i in range(0, ng):
if i == 1:
s = ' ' * len(s)
print("%sdevice%g _CudaDeviceProperties(name='%s', total_memory=%dMB)" %
logger.info("%sdevice%g _CudaDeviceProperties(name='%s', total_memory=%dMB)" %
(s, i, x[i].name, x[i].total_memory / c))
else:
print('Using CPU')
logger.info('Using CPU')

print('') # skip a line
logger.info('') # skip a line
return torch.device('cuda:0' if cuda else 'cpu')


@@ -142,7 +144,7 @@ def model_info(model, verbose=False):
except:
fs = ''

print('Model Summary: %g layers, %g parameters, %g gradients%s' % (len(list(model.parameters())), n_p, n_g, fs))
logger.info('Model Summary: %g layers, %g parameters, %g gradients%s' % (len(list(model.parameters())), n_p, n_g, fs))


def load_classifier(name='resnet101', n=2):
