
Added a class balanced distributed sampler #1844

Closed
wants to merge 1 commit

Conversation


@NatanBagrov (Contributor) commented Feb 18, 2024

Context: in cases where some classes are less frequent than others, we'd want to sample images containing these classes more often.

Current implementation:

  1. IndexMappingDatasetWrapper - gets a dataset and a list of indices (mapping), and wraps the dataset so that wrapper[i] == dataset[mapping[i]] (see the sketch after this list).
  2. An implementation of the repeat-factor sampling from https://arxiv.org/pdf/1908.03195.pdf, which returns a float per index (image) indicating the scarcity of the classes in that image (larger = scarcer = repeat it more often).
  3. DetectionClassBalancedDistributedSampler, which uses (1) with (2) for detection datasets.
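
To make (1) and (2) concrete, here is a minimal sketch of the two building blocks, assuming the repeat-factor formula r(c) = max(1, sqrt(t / f(c))) from the paper; the names and signatures here are illustrative and may differ from the actual diff:

import math
from collections import Counter
from typing import Dict, List, Sequence

from torch.utils.data import Dataset


class IndexMappingDatasetWrapper(Dataset):
    """Wraps `dataset` so that wrapper[i] == dataset[mapping[i]]."""

    def __init__(self, dataset: Dataset, mapping: Sequence[int]):
        self.dataset = dataset
        self.mapping = mapping

    def __getitem__(self, index: int):
        return self.dataset[self.mapping[index]]

    def __len__(self) -> int:
        return len(self.mapping)


def get_repeat_factors(index_to_classes: List[Sequence[int]], oversample_threshold: float) -> List[float]:
    """Repeat-factor sampling (https://arxiv.org/abs/1908.03195): result[i] is the repeat factor of image i (>= 1.0)."""
    num_images = len(index_to_classes)
    # f(c): fraction of images that contain class c
    images_per_class: Dict[int, int] = Counter(c for classes in index_to_classes for c in set(classes))
    class_repeat = {c: max(1.0, math.sqrt(oversample_threshold / (count / num_images))) for c, count in images_per_class.items()}
    # r(I) = max over the classes present in image I of r(c); images with no labels keep a factor of 1.0
    return [max((class_repeat[c] for c in set(classes)), default=1.0) for classes in index_to_classes]

The sampler then repeats each index math.ceil(repeat_factor) times and hands IndexMappingDatasetWrapper(dataset, repeat_indices) to DistributedSampler, as the diff further down shows.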

Discussion:
The current implementation supports distributed mode. Non-distributed mode could also be supported, but would require more code.
Some alternative implementations I can spot:

  1. Do not wrap the original dataset, but rather override __iter__ from DistributedSampler - it is possible, but it causes coupling and code duplication. The current implementation (super().__init__(dataset=IndexMappingDatasetWrapper(dataset, repeat_indices), *args, **kwargs)) is a bit less verbose.
  2. Eventually, we might use WeightedRandomSampler; however, it is not supported in distributed mode. There are some implementations of DistributedWeightedSampler in the wild. We could try these.
  3. Perhaps we should consider a different approach where the heavy lifting is on the dataset side. For example:
    dataset: class_balanced_wrapper
       inner_dataset: coco
       ...
    
    The issue, however, is that we cannot inherit inner_class, nor inherit DetectionDataset, because we don't want to pass all the __init__ parameters.
  4. Another option is to modify DetectionDataset and add a parameter for class balancing (i.e., oversampling_threshold) and do the heavy lifting inside DetectionDataset. This is less verbose; however, it adds more responsibility to DetectionDataset (compared to composition).

Open for discussion :)



@register_sampler(Samplers.DISTRIBUTED_DETECTION_CLASS_BALANCING)
class DetectionClassBalancedDistributedSampler(DistributedSampler):
Collaborator:

We have a dedicated module for our samplers, please move it there.

return repeat_factors


class IndexMappingDatasetWrapper(Dataset):
Collaborator:

I feel a bit uneasy about having a different dataset passed to the sampler than what we actually have in our data loader.
If anything happens under the hood (which I think does), things could go wrong.
I have a feeling @BloodAxe would be on the same page as me here, but what do you say?

return torch.tensor([idx, 0]) # class 0 appears everywhere, other classes appear only once.


class ClassBalancingTest(unittest.TestCase):
Collaborator:

Please add to suite.

return torch.tensor(idx)


class DatasetIndexMappingTest(unittest.TestCase):
Collaborator:

Please add to suite.

from super_gradients.training.datasets.detection_datasets.detection_dataset_class_balancing_wrapper import DetectionClassBalancedDistributedSampler


class DummyDetectionDataset(Dataset): # NOTE: we implement the needed stuff from DetectionDataset, but we do not inherit it because the ctor is massive
Collaborator:

Please add to suite.


from super_gradients.training.datasets.balancing_classes_utils import get_repeat_factors


Collaborator:

I think a comprehensive test that actually runs in DDP is missing.
Please add one (it can be done through the CI config - see the sanity_tests workflow).

@NatanBagrov (Contributor, Author) commented Feb 18, 2024:

Pff.. I agree. I'll try to implement it once we agree on the approach (i.e., it will be redundant if we decide on option 3 or 4).

2. For each category c, compute the category-level repeat factor: :math:`r(c) = max(1, sqrt(t/f(c)))`
3. For each image I, compute the image-level repeat factor: :math:`r(I) = max_{c in I} r(c)`

Returns a list of repeat factors (length = dataset_length). How to read: result[i] is a float indicating the repeat factor of image i.
Collaborator:

Out of curiosity - I guess this would mean that using Mosaic / mixup would mess things up here?
We use them in most of our recipes, so maybe there's something to take into account here.

@NatanBagrov (Contributor, Author):

So, you are half correct :) I believe things should work OK because the balancing is done (currently) at the sampler level, while mixup/mosaic are at the dataset level. On the other hand, the sampler will "balance" and ask for more indices with scarce classes (this is good), but the dataset does not know that it should sample non-uniformly, so 3 of the 4 mosaic images will be taken at random, regardless of class scarcity. Hope it makes sense...

for dataset_idx, repeat_factor in enumerate(repeat_factors):
    repeat_indices.extend([dataset_idx] * math.ceil(repeat_factor))

super().__init__(dataset=IndexMappingDatasetWrapper(dataset, repeat_indices), *args, **kwargs)
Collaborator:

Following my previous comment regarding the dataset wrapper, I am in favor of moving the logic introduced there (mapping etc.) into this class.
Also the helper methods, which I don't think have much context outside this specific sampler.

@BloodAxe (Collaborator) commented:

First of all - great initiative on adding sampler support! I feel this can be really helpful in many cases.
Regarding the suggested implementation - I do have a few comments:

  1. I will start with the most simple one. I think we should have a generic wrapper around any sampler for the DDP case: DistributedSamplerWrapper

Suppose we have a sampler (any subclass of Sampler); then for the DDP case we would do:
sampler = DistributedSamplerWrapper(sampler), and that would make the sampler compatible with DDP.

The suggested design of the wrapper was proposed on the PyTorch forums, has been tested for years (by myself included), and does not require any special knowledge of the underlying sampler implementation.

from operator import itemgetter
from typing import Optional

from torch.utils.data import Dataset, DistributedSampler, Sampler


class DatasetFromSampler(Dataset):
    """Dataset to create indexes from `Sampler`.

    Args:
        sampler: PyTorch sampler
    """

    def __init__(self, sampler: Sampler):
        """Initialisation for DatasetFromSampler."""
        self.sampler = sampler
        self.sampler_list = None

    def __getitem__(self, index: int):
        """Gets element of the dataset.

        Args:
            index: index of the element in the dataset

        Returns:
            Single element by index
        """
        if self.sampler_list is None:
            self.sampler_list = list(self.sampler)
        return self.sampler_list[index]

    def __len__(self) -> int:
        """
        Returns:
            int: length of the dataset
        """
        return len(self.sampler)


class DistributedSamplerWrapper(DistributedSampler):
    """
    Wrapper over `Sampler` for distributed training.
    Allows you to use any sampler in distributed mode.

    It is especially useful in conjunction with
    `torch.nn.parallel.DistributedDataParallel`. In such case, each
    process can pass a DistributedSamplerWrapper instance as a DataLoader
    sampler, and load a subset of subsampled data of the original dataset
    that is exclusive to it.

    .. note::
        Sampler is assumed to be of constant size.
    """

    def __init__(
        self,
        sampler,
        num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
        shuffle: bool = True,
    ):
        """

        Args:
            sampler: Sampler used for subsampling
            num_replicas (int, optional): Number of processes participating in
              distributed training
            rank (int, optional): Rank of the current process
              within ``num_replicas``
            shuffle (bool, optional): If true (default),
              sampler will shuffle the indices
        """
        super(DistributedSamplerWrapper, self).__init__(
            DatasetFromSampler(sampler),
            num_replicas=num_replicas,
            rank=rank,
            shuffle=shuffle,
        )
        self.sampler = sampler

    def __iter__(self):

        self.dataset = DatasetFromSampler(self.sampler)
        indexes_of_indexes = super().__iter__()
        subsampler_indexes = self.dataset
        return iter(itemgetter(*indexes_of_indexes)(subsampler_indexes))
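
To make the suggestion concrete, a minimal usage sketch assuming the classes above are in scope; the toy dataset, the inner WeightedRandomSampler, and the fixed num_replicas/rank are all illustrative:

import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

# Toy dataset and per-sample weights (in practice the weights would come from class frequencies / repeat factors).
dataset = TensorDataset(torch.arange(100))
weights = torch.rand(100)

# Any single-process sampler can be the inner sampler.
base_sampler = WeightedRandomSampler(weights=weights, num_samples=len(weights), replacement=True)

# Wrapping it makes it usable under DDP: each rank iterates over its own shard of the sampled indices.
sampler = DistributedSamplerWrapper(base_sampler, num_replicas=2, rank=0)

loader = DataLoader(dataset, batch_size=16, sampler=sampler)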

@BloodAxe (Collaborator) commented:

Secondly, I think we should introduce some sort of interface:

from abc import ABC, abstractmethod

import numpy as np


class HasSamplingInformation(ABC):
    @abstractmethod
    def get_labels_presence(self) -> np.ndarray:
        """
        :returns: A numpy array of shape [Dataset Length, Num Classes] with values corresponding to
                  the number of objects of each class at the current sample index.
        """

Why do we want to have an interface?

  • We can check whether the dataset we are trying to use with the sampler implements it (if not - we show a nice error).
  • The sampler implementation gets a presence matrix which contains ALL the necessary information to do the sampling - no remapping of indexes and other stuff. A hedged sketch of how a sampler could consume this follows below.
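
A minimal sketch of how a sampler could consume that interface, assuming the HasSamplingInformation ABC above is in scope; the inverse-frequency weighting is illustrative, not part of the proposal:

import numpy as np
from torch.utils.data import Dataset, Sampler, WeightedRandomSampler


def build_balanced_sampler(dataset: Dataset) -> Sampler:
    # Fail loudly if the dataset cannot describe its label distribution.
    if not isinstance(dataset, HasSamplingInformation):
        raise ValueError(f"{type(dataset).__name__} does not implement HasSamplingInformation")

    presence = dataset.get_labels_presence()        # [num_samples, num_classes]
    class_freq = presence.sum(axis=0).clip(min=1)   # objects per class, clipped to avoid division by zero
    # Illustrative weighting: a sample's weight is the largest inverse frequency among the classes it contains.
    sample_weights = ((presence > 0) / class_freq).max(axis=1)
    sample_weights = np.maximum(sample_weights, 1.0 / presence.shape[0])  # keep unlabeled images sampleable
    return WeightedRandomSampler(weights=sample_weights, num_samples=len(sample_weights), replacement=True)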

@shaydeci (Collaborator) commented:

First of all - great initiative on adding sampler support! [...] I think we should have a generic wrapper around any sampler for the DDP case: DistributedSamplerWrapper [...]

I like this ^^

@NatanBagrov (Contributor, Author) commented:

Secondly, I think we should introduce some sort of interface: HasSamplingInformation [...]

This is a nice idea. I wonder if an array is too much compared to a generator. A fixed array, when we take Objects365 as an example, is of size: 1.7M dataset size * 365 classes * 4 bytes ≈ 2.5 GB.
Another possible downside is cases where SG users have their own dataset and will not get this feature. On the other hand, perhaps it is fair to ask them to implement something to get value.
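
For reference, the back-of-the-envelope arithmetic behind the ~2.5 GB figure, plus a sketch of the generator alternative mentioned above (the per-sample accessor name is hypothetical):

import numpy as np

# Dense float32 presence matrix for Objects365-scale data:
num_samples, num_classes = 1_700_000, 365
dense_bytes = num_samples * num_classes * np.dtype(np.float32).itemsize
print(f"dense presence matrix: {dense_bytes / 1e9:.2f} GB")  # ~2.48 GB

# Generator alternative: yield one presence row at a time instead of materializing the full matrix.
def iter_labels_presence(dataset):
    for index in range(len(dataset)):
        yield dataset.get_labels_presence_for_index(index)  # hypothetical per-sample accessor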

@NatanBagrov (Contributor, Author) commented:

Closing because this PR has turned into two: #1865 #1856
