Remove `opt` from `create_dataloader()` (ultralytics#3552)
glenn-jocher committed Jun 9, 2021
1 parent c41fddc commit a708b90
Showing 3 changed files with 13 additions and 12 deletions.
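
The change is mechanical at every call site: callers that previously threaded the whole argparse `opt` namespace into `create_dataloader()` now pass the one flag the loader actually read, `opt.single_cls`. Schematically (a sketch of the pattern; the real signatures and call sites appear in the diffs below):

    # before: the entire argparse namespace was passed through
    dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, opt, ...)
    # after: only the flag that create_dataloader() actually consumed
    dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, single_cls, ...)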
2 changes: 1 addition & 1 deletion test.py
@@ -88,7 +88,7 @@ def test(data,
     if device.type != 'cpu':
         model(torch.zeros(1, 3, imgsz, imgsz).to(device).type_as(next(model.parameters())))  # run once
     task = opt.task if opt.task in ('train', 'val', 'test') else 'val'  # path to train/val/test images
-    dataloader = create_dataloader(data[task], imgsz, batch_size, gs, opt, pad=0.5, rect=True,
+    dataloader = create_dataloader(data[task], imgsz, batch_size, gs, single_cls, pad=0.5, rect=True,
                                    prefix=colorstr(f'{task}: '))[0]
 
     seen = 0
17 changes: 9 additions & 8 deletions train.py
@@ -41,8 +41,9 @@
 
 def train(hyp, opt, device, tb_writer=None):
     logger.info(colorstr('hyperparameters: ') + ', '.join(f'{k}={v}' for k, v in hyp.items()))
-    save_dir, epochs, batch_size, total_batch_size, weights, rank = \
-        Path(opt.save_dir), opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank
+    save_dir, epochs, batch_size, total_batch_size, weights, rank, single_cls = \
+        Path(opt.save_dir), opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank, \
+        opt.single_cls
 
     # Directories
     wdir = save_dir / 'weights'
@@ -75,8 +76,8 @@ def train(hyp, opt, device, tb_writer=None):
         if wandb_logger.wandb:
             weights, epochs, hyp = opt.weights, opt.epochs, opt.hyp  # WandbLogger might update weights, epochs if resuming
 
-    nc = 1 if opt.single_cls else int(data_dict['nc'])  # number of classes
-    names = ['item'] if opt.single_cls and len(data_dict['names']) != 1 else data_dict['names']  # class names
+    nc = 1 if single_cls else int(data_dict['nc'])  # number of classes
+    names = ['item'] if single_cls and len(data_dict['names']) != 1 else data_dict['names']  # class names
     assert len(names) == nc, '%g names found for nc=%g dataset in %s' % (len(names), nc, opt.data)  # check
     is_coco = opt.data.endswith('coco.yaml') and nc == 80  # COCO dataset
 
@@ -187,7 +188,7 @@ def train(hyp, opt, device, tb_writer=None):
         logger.info('Using SyncBatchNorm()')
 
     # Trainloader
-    dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, opt,
+    dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, single_cls,
                                             hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect, rank=rank,
                                             world_size=opt.world_size, workers=opt.workers,
                                             image_weights=opt.image_weights, quad=opt.quad, prefix=colorstr('train: '))
@@ -197,7 +198,7 @@ def train(hyp, opt, device, tb_writer=None):
 
     # Process 0
     if rank in [-1, 0]:
-        testloader = create_dataloader(test_path, imgsz_test, batch_size * 2, gs, opt,  # testloader
+        testloader = create_dataloader(test_path, imgsz_test, batch_size * 2, gs, single_cls,
                                        hyp=hyp, cache=opt.cache_images and not opt.notest, rect=True, rank=-1,
                                        world_size=opt.world_size, workers=opt.workers,
                                        pad=0.5, prefix=colorstr('val: '))[0]
@@ -357,7 +358,7 @@ def train(hyp, opt, device, tb_writer=None):
                                                  batch_size=batch_size * 2,
                                                  imgsz=imgsz_test,
                                                  model=ema.ema,
-                                                 single_cls=opt.single_cls,
+                                                 single_cls=single_cls,
                                                  dataloader=testloader,
                                                  save_dir=save_dir,
                                                  save_json=is_coco and final_epoch,
@@ -429,7 +430,7 @@ def train(hyp, opt, device, tb_writer=None):
                                           conf_thres=0.001,
                                           iou_thres=0.7,
                                           model=attempt_load(m, device).half(),
-                                          single_cls=opt.single_cls,
+                                          single_cls=single_cls,
                                           dataloader=testloader,
                                           save_dir=save_dir,
                                           save_json=True,
6 changes: 3 additions & 3 deletions utils/datasets.py
@@ -62,16 +62,16 @@ def exif_size(img):
     return s
 
 
-def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=False, cache=False, pad=0.0, rect=False,
-                      rank=-1, world_size=1, workers=8, image_weights=False, quad=False, prefix=''):
+def create_dataloader(path, imgsz, batch_size, stride, single_cls=False, hyp=None, augment=False, cache=False, pad=0.0,
+                      rect=False, rank=-1, world_size=1, workers=8, image_weights=False, quad=False, prefix=''):
     # Make sure only the first process in DDP process the dataset first, and the following others can use the cache
     with torch_distributed_zero_first(rank):
         dataset = LoadImagesAndLabels(path, imgsz, batch_size,
                                       augment=augment,  # augment images
                                       hyp=hyp,  # augmentation hyperparameters
                                       rect=rect,  # rectangular training
                                       cache_images=cache,
-                                      single_cls=opt.single_cls,
+                                      single_cls=single_cls,
                                       stride=int(stride),
                                       pad=pad,
                                       image_weights=image_weights,
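
A side benefit of dropping `opt` from the signature: the loader can now be built from plain values, e.g. in a notebook or a standalone script, with no argparse namespace involved. A minimal sketch of the new call under assumed values (the dataset path, batch size, and worker count below are hypothetical, not from this commit):

from utils.datasets import create_dataloader

# hypothetical YOLO-format dataset path with typical YOLOv5 settings (imgsz 640, stride 32)
dataloader, dataset = create_dataloader('data/images/train', imgsz=640, batch_size=16,
                                        stride=32, single_cls=False, rect=False,
                                        workers=4, prefix='train: ')
print(f'{len(dataset)} images in {len(dataloader)} batches')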
