Commit ed5d9c1

train255 authored and glenn-jocher committed Feb 5, 2021
1 parent 4bdc5a3 commit ed5d9c1
Showing 5 changed files with 12 additions and 10 deletions.
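In short, this commit threads a new data_type keyword ("train", "val", or "test") from the dataloader call sites in test.py, train.py, utils/autoanchor.py, and utils/wandb_logging/log_dataset.py through create_dataloader into LoadImagesAndLabels, where it renames the label-cache file per split (see the utils/datasets.py hunk below). A minimal sketch of the resulting call shape; the paths and sizes are illustrative placeholders, and opt stands in for the argparse options object the repository normally passes:

# Hypothetical usage of the updated factory; it returns (dataloader, dataset).
train_loader, train_set = create_dataloader('data/images/train', 640, 16, 32, opt,
                                            data_type='train', augment=True)
val_loader = create_dataloader('data/images/val', 640, 16, 32, opt,
                               data_type='val', rect=True, pad=0.5)[0]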
test.py: 2 changes (1 addition, 1 deletion)
@@ -85,7 +85,7 @@ def test(data,
if device.type != 'cpu':
model(torch.zeros(1, 3, imgsz, imgsz).to(device).type_as(next(model.parameters()))) # run once
path = data['test'] if opt.task == 'test' else data['val'] # path to val/test images
- dataloader = create_dataloader(path, imgsz, batch_size, model.stride.max(), opt, pad=0.5, rect=True,
+ dataloader = create_dataloader(path, imgsz, batch_size, model.stride.max(), opt, data_type="test", pad=0.5, rect=True,
prefix=colorstr('test: ' if opt.task == 'test' else 'val: '))[0]

seen = 0
train.py: 4 changes (2 additions, 2 deletions)
@@ -181,7 +181,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
# Trainloader
dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, opt,
hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect, rank=rank,
- world_size=opt.world_size, workers=opt.workers,
+ world_size=opt.world_size, workers=opt.workers, data_type="train",
image_weights=opt.image_weights, quad=opt.quad, prefix=colorstr('train: '))
mlc = np.concatenate(dataset.labels, 0)[:, 0].max() # max label class
nb = len(dataloader) # number of batches
@@ -192,7 +192,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
ema.updates = start_epoch * nb // accumulate # set EMA updates
testloader = create_dataloader(test_path, imgsz_test, total_batch_size, gs, opt, # testloader
hyp=hyp, cache=opt.cache_images and not opt.notest, rect=True, rank=-1,
- world_size=opt.world_size, workers=opt.workers,
+ world_size=opt.world_size, workers=opt.workers, data_type="val",
pad=0.5, prefix=colorstr('val: '))[0]

if not opt.resume:
utils/autoanchor.py: 2 changes (1 addition, 1 deletion)
@@ -100,7 +100,7 @@ def print_results(k):
with open(path) as f:
data_dict = yaml.load(f, Loader=yaml.SafeLoader) # model dict
from utils.datasets import LoadImagesAndLabels
- dataset = LoadImagesAndLabels(data_dict['train'], augment=True, rect=True)
+ dataset = LoadImagesAndLabels(data_dict['train'], data_type="train", augment=True, rect=True)
else:
dataset = path # dataset

utils/datasets.py: 10 changes (6 additions, 4 deletions)
@@ -56,10 +56,10 @@ def exif_size(img):


def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=False, cache=False, pad=0.0, rect=False,
- rank=-1, world_size=1, workers=8, image_weights=False, quad=False, prefix=''):
+ rank=-1, world_size=1, workers=8, data_type="train", image_weights=False, quad=False, prefix=''):
# Make sure only the first process in DDP process the dataset first, and the following others can use the cache
with torch_distributed_zero_first(rank):
- dataset = LoadImagesAndLabels(path, imgsz, batch_size,
+ dataset = LoadImagesAndLabels(path, imgsz, batch_size, data_type=data_type,
augment=augment, # augment images
hyp=hyp, # augmentation hyperparameters
rect=rect, # rectangular training
@@ -337,7 +337,7 @@ def img2label_paths(img_paths):


class LoadImagesAndLabels(Dataset): # for training/testing
- def __init__(self, path, img_size=640, batch_size=16, augment=False, hyp=None, rect=False, image_weights=False,
+ def __init__(self, path, img_size=640, batch_size=16, data_type="train", augment=False, hyp=None, rect=False, image_weights=False,
cache_images=False, single_cls=False, stride=32, pad=0.0, prefix=''):
self.img_size = img_size
self.augment = augment
@@ -372,7 +372,9 @@ def __init__(self, path, img_size=640, batch_size=16, augment=False, hyp=None, r

# Check cache
self.label_files = img2label_paths(self.img_files) # labels
- cache_path = Path(self.label_files[0]).parent.with_suffix('.cache') # cached labels
+ cache_basename = os.path.basename(os.path.dirname(self.label_files[0]))
+ new_cache = self.label_files[0].replace(cache_basename, data_type)
+ cache_path = Path(new_cache).parent.with_suffix('.cache') # cached labels
if cache_path.is_file():
cache = torch.load(cache_path) # load
if cache['hash'] != get_hash(self.label_files + self.img_files) or 'results' not in cache: # changed
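For context, the hunk above replaces a cache file named after the labels directory with one named after the split, presumably so that splits whose label files resolve to the same directory no longer share (and overwrite) a single .cache file. A standalone sketch of the derivation; the label path below is made up for illustration:

import os
from pathlib import Path

label_file = 'coco/labels/images2017/000001.txt'  # first label file of a split

# Old naming: cache keyed to the labels directory -> 'images2017.cache',
# the same file for every split that reads labels from this directory.
old_cache = Path(label_file).parent.with_suffix('.cache')  # coco/labels/images2017.cache

# New naming: the directory basename is swapped for data_type, so each
# split gets its own cache file ('train.cache', 'val.cache', 'test.cache').
for data_type in ('train', 'val', 'test'):
    basename = os.path.basename(os.path.dirname(label_file))  # 'images2017'
    new_cache = Path(label_file.replace(basename, data_type)).parent.with_suffix('.cache')
    print(data_type, '->', new_cache)  # e.g. train -> coco/labels/train.cache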
utils/wandb_logging/log_dataset.py: 4 changes (2 additions, 2 deletions)
@@ -15,8 +15,8 @@ def create_dataset_artifact(opt):
logger = WandbLogger(opt, '', None, data, job_type='create_dataset')
nc, names = (1, ['item']) if opt.single_cls else (int(data['nc']), data['names'])
names = {k: v for k, v in enumerate(names)} # to index dictionary
- logger.log_dataset_artifact(LoadImagesAndLabels(data['train']), names, name='train') # trainset
- logger.log_dataset_artifact(LoadImagesAndLabels(data['val']), names, name='val') # valset
+ logger.log_dataset_artifact(LoadImagesAndLabels(data['train'], data_type="train"), names, name='train') # trainset
+ logger.log_dataset_artifact(LoadImagesAndLabels(data['val'], data_type="val"), names, name='val') # valset

# Update data.yaml with artifact links
data['train'] = WANDB_ARTIFACT_PREFIX + str(Path(opt.project) / 'train')
