From 958ab92dc1a29f41f4c813937fda2bc99e1f147b Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Wed, 9 Jun 2021 13:14:56 +0200 Subject: [PATCH] Remove `opt` from `create_dataloader()`` (#3552) --- test.py | 2 +- train.py | 17 +++++++++-------- utils/datasets.py | 6 +++--- 3 files changed, 13 insertions(+), 12 deletions(-) diff --git a/test.py b/test.py index 515b984bc7be..16a31fd17a54 100644 --- a/test.py +++ b/test.py @@ -88,7 +88,7 @@ def test(data, if device.type != 'cpu': model(torch.zeros(1, 3, imgsz, imgsz).to(device).type_as(next(model.parameters()))) # run once task = opt.task if opt.task in ('train', 'val', 'test') else 'val' # path to train/val/test images - dataloader = create_dataloader(data[task], imgsz, batch_size, gs, opt, pad=0.5, rect=True, + dataloader = create_dataloader(data[task], imgsz, batch_size, gs, single_cls, pad=0.5, rect=True, prefix=colorstr(f'{task}: '))[0] seen = 0 diff --git a/train.py b/train.py index aad8ff258d6e..2465a8c22a37 100644 --- a/train.py +++ b/train.py @@ -41,8 +41,9 @@ def train(hyp, opt, device, tb_writer=None): logger.info(colorstr('hyperparameters: ') + ', '.join(f'{k}={v}' for k, v in hyp.items())) - save_dir, epochs, batch_size, total_batch_size, weights, rank = \ - Path(opt.save_dir), opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank + save_dir, epochs, batch_size, total_batch_size, weights, rank, single_cls = \ + Path(opt.save_dir), opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank, \ + opt.single_cls # Directories wdir = save_dir / 'weights' @@ -75,8 +76,8 @@ def train(hyp, opt, device, tb_writer=None): if wandb_logger.wandb: weights, epochs, hyp = opt.weights, opt.epochs, opt.hyp # WandbLogger might update weights, epochs if resuming - nc = 1 if opt.single_cls else int(data_dict['nc']) # number of classes - names = ['item'] if opt.single_cls and len(data_dict['names']) != 1 else data_dict['names'] # class names + nc = 1 if single_cls else int(data_dict['nc']) # number of classes + names = ['item'] if single_cls and len(data_dict['names']) != 1 else data_dict['names'] # class names assert len(names) == nc, '%g names found for nc=%g dataset in %s' % (len(names), nc, opt.data) # check is_coco = opt.data.endswith('coco.yaml') and nc == 80 # COCO dataset @@ -187,7 +188,7 @@ def train(hyp, opt, device, tb_writer=None): logger.info('Using SyncBatchNorm()') # Trainloader - dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, opt, + dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, single_cls, hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect, rank=rank, world_size=opt.world_size, workers=opt.workers, image_weights=opt.image_weights, quad=opt.quad, prefix=colorstr('train: ')) @@ -197,7 +198,7 @@ def train(hyp, opt, device, tb_writer=None): # Process 0 if rank in [-1, 0]: - testloader = create_dataloader(test_path, imgsz_test, batch_size * 2, gs, opt, # testloader + testloader = create_dataloader(test_path, imgsz_test, batch_size * 2, gs, single_cls, hyp=hyp, cache=opt.cache_images and not opt.notest, rect=True, rank=-1, world_size=opt.world_size, workers=opt.workers, pad=0.5, prefix=colorstr('val: '))[0] @@ -357,7 +358,7 @@ def train(hyp, opt, device, tb_writer=None): batch_size=batch_size * 2, imgsz=imgsz_test, model=ema.ema, - single_cls=opt.single_cls, + single_cls=single_cls, dataloader=testloader, save_dir=save_dir, save_json=is_coco and final_epoch, @@ -429,7 +430,7 @@ def train(hyp, opt, device, tb_writer=None): conf_thres=0.001, iou_thres=0.7, model=attempt_load(m, device).half(), - single_cls=opt.single_cls, + single_cls=single_cls, dataloader=testloader, save_dir=save_dir, save_json=True, diff --git a/utils/datasets.py b/utils/datasets.py index 108005c8de65..444b3ff2f60c 100755 --- a/utils/datasets.py +++ b/utils/datasets.py @@ -62,8 +62,8 @@ def exif_size(img): return s -def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=False, cache=False, pad=0.0, rect=False, - rank=-1, world_size=1, workers=8, image_weights=False, quad=False, prefix=''): +def create_dataloader(path, imgsz, batch_size, stride, single_cls=False, hyp=None, augment=False, cache=False, pad=0.0, + rect=False, rank=-1, world_size=1, workers=8, image_weights=False, quad=False, prefix=''): # Make sure only the first process in DDP process the dataset first, and the following others can use the cache with torch_distributed_zero_first(rank): dataset = LoadImagesAndLabels(path, imgsz, batch_size, @@ -71,7 +71,7 @@ def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=Fa hyp=hyp, # augmentation hyperparameters rect=rect, # rectangular training cache_images=cache, - single_cls=opt.single_cls, + single_cls=single_cls, stride=int(stride), pad=pad, image_weights=image_weights,