No cache option for reading datasets #4376

Merged: 6 commits, Aug 13, 2021
Changes from 2 commits
train.py: 8 changes (5 additions, 3 deletions)
@@ -55,9 +55,9 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
device,
callbacks=Callbacks()
):
-save_dir, epochs, batch_size, weights, single_cls, evolve, data, cfg, resume, noval, nosave, workers, freeze, = \
+save_dir, epochs, batch_size, weights, single_cls, evolve, data, cfg, resume, noval, nosave, workers, freeze, read_data_from_cache = \
Path(opt.save_dir), opt.epochs, opt.batch_size, opt.weights, opt.single_cls, opt.evolve, opt.data, opt.cfg, \
-opt.resume, opt.noval, opt.nosave, opt.workers, opt.freeze
+opt.resume, opt.noval, opt.nosave, opt.workers, opt.freeze, opt.read_data_from_cache

# Directories
w = save_dir / 'weights' # weights dir
@@ -203,7 +203,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
train_loader, dataset = create_dataloader(train_path, imgsz, batch_size // WORLD_SIZE, gs, single_cls,
hyp=hyp, augment=True, cache=opt.cache, rect=opt.rect, rank=RANK,
workers=workers, image_weights=opt.image_weights, quad=opt.quad,
-prefix=colorstr('train: '))
+prefix=colorstr('train: '), read_data_from_cache=read_data_from_cache)
mlc = np.concatenate(dataset.labels, 0)[:, 0].max() # max label class
nb = len(train_loader) # number of batches
assert mlc < nc, f'Label class {mlc} exceeds nc={nc} in {data}. Possible class labels are 0-{nc - 1}'
@@ -452,6 +452,8 @@ def parse_opt(known=False):
parser.add_argument('--artifact_alias', type=str, default="latest", help='version of dataset artifact to be used')
parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify')
parser.add_argument('--freeze', type=int, default=0, help='Number of layers to freeze. backbone=10, all=24')
+parser.add_argument('--no-cache-read', dest='read_data_from_cache', action='store_false', help='skip an existing dataset .cache file and rebuild labels from disk (cache is read by default)')

opt = parser.parse_known_args()[0] if known else parser.parse_args()
return opt

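To make the default behaviour of the new option explicit, here is a tiny self-contained check of the store_false pattern used in the argument above; the --no-cache-read name mirrors that argument and the snippet is illustrative only, not repository code:

import argparse

# A store_false flag keeps cache reading enabled by default and only disables it when passed.
parser = argparse.ArgumentParser()
parser.add_argument('--no-cache-read', dest='read_data_from_cache', action='store_false',
                    help='rebuild dataset labels from disk even if a .cache file exists')

print(parser.parse_args([]).read_data_from_cache)                   # True  (default: reuse the cache)
print(parser.parse_args(['--no-cache-read']).read_data_from_cache)  # False (force a fresh scan)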
utils/datasets.py: 20 changes (13 additions, 7 deletions)
@@ -89,7 +89,7 @@ def exif_transpose(image):


def create_dataloader(path, imgsz, batch_size, stride, single_cls=False, hyp=None, augment=False, cache=False, pad=0.0,
-rect=False, rank=-1, workers=8, image_weights=False, quad=False, prefix=''):
+rect=False, rank=-1, workers=8, image_weights=False, quad=False, prefix='', read_data_from_cache=True):
# Make sure only the first process in DDP processes the dataset first, so the following ones can use the cache
with torch_distributed_zero_first(rank):
dataset = LoadImagesAndLabels(path, imgsz, batch_size,
@@ -101,7 +101,7 @@ def create_dataloader(path, imgsz, batch_size, stride, single_cls=False, hyp=None, augment=False, cache=False, pad=0.0,
stride=int(stride),
pad=pad,
image_weights=image_weights,
-prefix=prefix)
+prefix=prefix, read_data_from_cache=read_data_from_cache)

batch_size = min(batch_size, len(dataset))
nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, workers]) # number of workers
@@ -361,7 +361,7 @@ def img2label_paths(img_paths):

class LoadImagesAndLabels(Dataset): # for training/testing
def __init__(self, path, img_size=640, batch_size=16, augment=False, hyp=None, rect=False, image_weights=False,
-cache_images=False, single_cls=False, stride=32, pad=0.0, prefix=''):
+cache_images=False, single_cls=False, stride=32, pad=0.0, prefix='', read_data_from_cache=True):
self.img_size = img_size
self.augment = augment
self.hyp = hyp
@@ -397,12 +397,18 @@ def __init__(self, path, img_size=640, batch_size=16, augment=False, hyp=None, rect=False, image_weights=False,
         # Check cache
         self.label_files = img2label_paths(self.img_files)  # labels
         cache_path = (p if p.is_file() else Path(self.label_files[0]).parent).with_suffix('.cache')
-        try:
-            cache, exists = np.load(cache_path, allow_pickle=True).item(), True  # load dict
-            assert cache['version'] == 0.4 and cache['hash'] == get_hash(self.label_files + self.img_files)
-        except:
-            cache, exists = self.cache_labels(cache_path, prefix), False  # cache
+
+        if read_data_from_cache:  # if False, labels are rebuilt from disk even if a cache file exists
+            try:
+                cache, exists = np.load(cache_path, allow_pickle=True).item(), True  # load dict
+                assert cache['version'] == 0.4 and cache['hash'] == get_hash(self.label_files + self.img_files)
+            except:
+                print('No cache exists, reading from disk instead')
+                cache, exists = self.cache_labels(cache_path, prefix), False  # cache
+        else:
+            cache, exists = self.cache_labels(cache_path, prefix), False  # cache

         # Display cache
         nf, nm, ne, nc, n = cache.pop('results')  # found, missing, empty, corrupted, total
         if exists:
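To summarize the control flow added to LoadImagesAndLabels, here is a condensed, self-contained sketch of the same read-cache-or-rebuild pattern; the helper names load_or_build_cache and build_cache and the toy cache dict are illustrative assumptions, not the repository's API:

from pathlib import Path
import numpy as np

def build_cache(cache_path: Path) -> dict:
    # Stand-in for LoadImagesAndLabels.cache_labels(): scan the dataset and build the label dict.
    # (The real method also writes the dict back to cache_path for the next run.)
    return {'version': 0.4, 'hash': 'example-hash', 'results': (0, 0, 0, 0, 0)}

def load_or_build_cache(cache_path: Path, expected_hash: str, read_from_cache: bool = True):
    if read_from_cache:
        try:
            cache = np.load(str(cache_path), allow_pickle=True).item()  # saved dict from a previous run
            assert cache['version'] == 0.4 and cache['hash'] == expected_hash
            return cache, True                      # valid cache found and reused
        except Exception:
            print('No cache exists, reading from disk instead')
    return build_cache(cache_path), False           # cache disabled, missing, or stale: rebuild

# Example: force a rebuild regardless of any existing cache, as disabling read_data_from_cache would.
cache, from_cache = load_or_build_cache(Path('labels.cache'), 'example-hash', read_from_cache=False)
print(from_cache)  # False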