
[Improvement] Got 2x speed using custom Graph dataloader #3344

Closed · felipemello1 opened this issue Sep 10, 2021 · 1 comment

@felipemello1

🚀 Feature

Consider adding FastDataLoader to the DGL library.

# DGL's built-in GraphDataLoader
traindataloader = dgl.dataloading.GraphDataLoader(trainDataSet, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=num_workers)
%time next(iter(traindataloader))
# Wall time: 13.3 s

# Custom dataloader (defined below)
traindataloader = CustomGraphDataLoader(trainDataSet, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=num_workers)
%time next(iter(traindataloader))
# Wall time: 6.72 s

Motivation

I got a 2x speedup using this drop-in replacement extracted from here:

import inspect

import torch
from dgl.dataloading import GraphCollator


class _RepeatSampler(object):
    """Sampler that repeats forever, so the underlying DataLoader
    never exhausts its batch sampler and never tears down its workers.

    Args:
        sampler (Sampler)
    """

    def __init__(self, sampler):
        self.sampler = sampler

    def __iter__(self):
        while True:
            yield from iter(self.sampler)


class FastDataLoader(torch.utils.data.dataloader.DataLoader):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # DataLoader forbids reassigning batch_sampler after __init__,
        # so bypass its __setattr__ guard.
        object.__setattr__(self, 'batch_sampler', _RepeatSampler(self.batch_sampler))
        # Create the (worker-spawning) iterator once and reuse it across epochs.
        self.iterator = super().__iter__()

    def __len__(self):
        # _RepeatSampler wraps the original batch sampler,
        # whose length is the number of batches per epoch.
        return len(self.batch_sampler.sampler)

    def __iter__(self):
        for i in range(len(self)):
            yield next(self.iterator)


class CustomGraphDataLoader:
    """PyTorch dataloader for batch-iterating over a set of graphs, generating the batched
    graph and corresponding label tensor (if provided) of the said minibatch.

    Parameters
    ----------
    collate_fn : Function, default is None
        The customized collate function. Will use the default collate
        function if not given.
    kwargs : dict
        Arguments being passed to :py:class:`torch.utils.data.DataLoader`.

    Examples
    --------
    To train a GNN for graph classification on a set of graphs in ``dataset`` (assume
    the backend is PyTorch):

    >>> dataloader = CustomGraphDataLoader(
    ...     dataset, batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
    >>> for batched_graph, labels in dataloader:
    ...     train_on(batched_graph, labels)
    """
    collator_arglist = inspect.getfullargspec(GraphCollator).args

    def __init__(self, dataset, collate_fn=None, **kwargs):
        collator_kwargs = {}
        dataloader_kwargs = {}
        for k, v in kwargs.items():
            if k in self.collator_arglist:
                collator_kwargs[k] = v
            else:
                dataloader_kwargs[k] = v

        if collate_fn is None:
            self.collate = GraphCollator(**collator_kwargs).collate
        else:
            self.collate = collate_fn

        self.dataloader = FastDataLoader(dataset=dataset,
                                         collate_fn=self.collate,
                                         **dataloader_kwargs)

    def __iter__(self):
        """Return the iterator of the data loader."""
        return iter(self.dataloader)

    def __len__(self):
        """Return the number of batches of the data loader."""
        return len(self.dataloader)

Cons

Apparently it may break in some situations: Lightning-AI/pytorch-lightning#1506
And it has some limitations: pytorch/pytorch#15849 (comment); one concrete gotcha is sketched below.
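For instance, because FastDataLoader creates its iterator once and shares it across epochs, breaking out of an epoch early leaves that iterator mid-stream, and the next epoch resumes from there instead of restarting. A toy illustration (the dataset here is hypothetical, just to make the effect visible):

data = list(range(8))
loader = FastDataLoader(data, batch_size=2, shuffle=False, num_workers=0)

for i, batch in enumerate(loader):
    if i == 1:
        break  # leave the epoch early; the persistent iterator keeps its position

# The next "epoch" does not restart from the beginning:
print(next(iter(loader)))  # tensor([4, 5]), not tensor([0, 1])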

@BarclayII (Collaborator)

Looks like the PyTorch issue suggested that the overhead is due to the re-initialization of worker processes. Did you try the persistent_workers option in newer PyTorch versions?
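
For reference, a minimal sketch of that alternative (assuming PyTorch >= 1.7, which added persistent_workers; per the docstring above, GraphDataLoader forwards extra keyword arguments to torch.utils.data.DataLoader):

import dgl

# persistent_workers=True keeps worker processes alive across epochs,
# avoiding the re-initialization overhead without replacing the sampler.
# It only takes effect when num_workers > 0.
traindataloader = dgl.dataloading.GraphDataLoader(
    trainDataSet,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
    num_workers=num_workers,
    persistent_workers=True,
)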
