Added a class balanced distributed sampler #1844
Conversation
```python
@register_sampler(Samplers.DISTRIBUTED_DETECTION_CLASS_BALANCING)
class DetectionClassBalancedDistributedSampler(DistributedSampler):
```
We have a dedicated module for our samplers, please move it there.
```python
    return repeat_factors


class IndexMappingDatasetWrapper(Dataset):
```
I feel a bit uneasy about having a different dataset passed to the sampler than what we actually have in our data loader.
If anything happens under the hood (which I think it does), things could go wrong.
I have a feeling @BloodAxe would be on the same page as me here, but what do you say?
```python
        return torch.tensor([idx, 0])  # class 0 appears everywhere, other classes appear only once


class ClassBalancingTest(unittest.TestCase):
```
Please add to suite.
```python
        return torch.tensor(idx)


class DatasetIndexMappingTest(unittest.TestCase):
```
Please add to suite.
```python
from super_gradients.training.datasets.detection_datasets.detection_dataset_class_balancing_wrapper import DetectionClassBalancedDistributedSampler


class DummyDetectionDataset(Dataset):  # NOTE: we implement the needed parts of DetectionDataset, but we do not inherit from it because the ctor is massive
```
Please add to suite.
```python
from super_gradients.training.datasets.balancing_classes_utils import get_repeat_factors
```
I think a comprehensive test that actually runs in DDP is missing.
Please add one (this can be done through the CI config - see the sanity_tests workflow).
Pff.. I agree. I'll try to implement it once we agree on the approach (i.e., it will be redundant if we decide on option 3 or 4).
```python
    2. For each category c, compute the category-level repeat factor: :math:`r(c) = max(1, sqrt(t/f(c)))`
    3. For each image I, compute the image-level repeat factor: :math:`r(I) = max_{c in I} r(c)`

    Returns a list of repeat factors (length = dataset_length). How to read: result[i] is a float indicating the repeat factor of image i.
```
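Under the definitions quoted above, the whole repeat-factor computation can be sketched in plain Python. This is an illustrative sketch of the technique (repeat-factor sampling, as in LVIS); `compute_repeat_factors`, `image_labels`, and `t` are hypothetical names, not necessarily the PR's actual signature:

```python
import math
from collections import Counter


def compute_repeat_factors(image_labels, t):
    """Sketch of repeat-factor sampling.

    image_labels: list where image_labels[i] is the set of category ids present in image i.
    t: the oversampling threshold on category frequency f(c).
    """
    n_images = len(image_labels)
    # f(c): fraction of images that contain category c
    image_count = Counter(c for labels in image_labels for c in set(labels))
    freq = {c: n / n_images for c, n in image_count.items()}
    # step 2: r(c) = max(1, sqrt(t / f(c)))
    cat_repeat = {c: max(1.0, math.sqrt(t / f)) for c, f in freq.items()}
    # step 3: r(I) = max over categories in image I (1.0 for images with no labels)
    return [max((cat_repeat[c] for c in set(labels)), default=1.0) for labels in image_labels]
```

For instance, with four images where class 0 appears everywhere and class 1 appears only in the first image, only the first image gets a repeat factor above 1.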
Out of curiosity - I guess this would mean that using Mosaic / mixup would mess things up here?
We use them in most of our recipes, so maybe there's something to take into account here.
So, you are half correct :) I believe things should work OK because the balancing is done (currently) at the sampler level, while mixup/mosaic are at the dataset level. On the other hand, the sampler will "balance" and ask for more indices with scarce classes (this is good), but the dataset does not know that it should sample non-uniformly, thus 3 of the 4 mosaic images will be taken at random, regardless of class scarcity. Hope it makes sense...
```python
        for dataset_idx, repeat_factor in enumerate(repeat_factors):
            repeat_indices.extend([dataset_idx] * math.ceil(repeat_factor))

        super().__init__(dataset=IndexMappingDatasetWrapper(dataset, repeat_indices), *args, **kwargs)
```
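The quoted expansion can be exercised in isolation. The standalone function below is illustrative (the PR inlines this loop in the sampler's `__init__`):

```python
import math


def expand_indices(repeat_factors):
    """Repeat each dataset index ceil(repeat_factor) times (sketch of the quoted loop)."""
    repeat_indices = []
    for dataset_idx, repeat_factor in enumerate(repeat_factors):
        repeat_indices.extend([dataset_idx] * math.ceil(repeat_factor))
    return repeat_indices
```

For example, `expand_indices([1.0, 2.3, 1.0])` gives `[0, 1, 1, 1, 2]`: image 1's fractional factor of 2.3 is rounded up to three appearances.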
Following my previous comment regarding the dataset wrapper, I am in favor of moving the logic introduced there (mapping etc.) into this class.
Also the helper methods, which I don't think have much context outside this specific sampler.
First of all - great initiative on adding samplers support! I feel this can be really helpful in many cases. The suggested wrapper design works with any sampler subclass; it was proposed on the pytorch forums, has been tested for years (by myself among others), and does not require any special knowledge of the underlying sampler implementation.

Secondly, I think we should introduce some sort of interface. Why do we want to have an interface?
I like this ^^
This is a nice idea. I wonder if an array is too much compared to a generator. A fixed array, when we take Objects365 as an example, is of size:
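A lazy variant is easy to sketch (illustrative names, not part of the PR); the catch is that `DistributedSampler` needs `len()` and random access for shuffling, so a pure generator would still have to be materialized at some point, or replaced by an arithmetic index mapping:

```python
import math


def iter_repeat_indices(repeat_factors):
    """Yield each dataset index ceil(repeat_factor) times without building the full list."""
    for dataset_idx, repeat_factor in enumerate(repeat_factors):
        for _ in range(math.ceil(repeat_factor)):
            yield dataset_idx
```

This keeps memory constant while iterating, at the cost of losing O(1) random access into the expanded index sequence.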
Context: in cases where some classes are less frequent than others, we'd want to sample images containing these classes more often.

Current implementation:
1. `get_repeat_factors` - computes a repeat factor per image from the class frequencies in the dataset.
2. `IndexMappingDatasetWrapper` - gets a dataset and a list of indices (`mapping`), and wraps the dataset: `wrapper[i] == self.dataset[self.mapping[i]]`
3. `DetectionClassBalancedDistributedSampler` - uses (1) with (2) for Detection datasets.

Discussion:
- The current implementation supports distributed mode. Non-distributed mode can also be supported, but will result in more code.
- Some alternative implementations I can spot:
  1. Override `__iter__` from `DistributedSampler` - it is possible; however, it causes coupling and code duplication. The current implementation (`super().__init__(dataset=IndexMappingDatasetWrapper(dataset, repeat_indices), *args, **kwargs)`) is a bit less verbose.
  2. Use `WeightedSampler`; however, it is not supported in distributed mode. There are some implementations of `DistributedWeightedSampler` in the wild. We can try these.
  3. Make the wrapper an inner class; it does not inherit from `DetectionDataset` because we don't want to pass all the `__init__` parameters.
  4. Inherit `DetectionDataset` and add a parameter for class balancing (i.e., `oversampling_threshold`), doing the heavy lifting inside `DetectionDataset`. This is less verbose; however, it adds more responsibility to `DetectionDataset` (compared to composition).

Open for discussion :)
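For concreteness, the wrapper contract described above can be sketched without torch. This is a minimal illustration only; the real class subclasses `torch.utils.data.Dataset`, and the sampler composes it with `DistributedSampler`:

```python
class IndexMappingDatasetWrapper:
    """Minimal sketch of the wrapper contract: wrapper[i] == dataset[mapping[i]].

    The real implementation subclasses torch.utils.data.Dataset; this plain
    class only illustrates the index-redirection behaviour.
    """

    def __init__(self, dataset, mapping):
        self.dataset = dataset
        self.mapping = mapping

    def __len__(self):
        # the wrapper's length is the (possibly larger) mapping, not the dataset
        return len(self.mapping)

    def __getitem__(self, item):
        # redirect every lookup through the mapping
        return self.dataset[self.mapping[item]]
```

Wrapping a 2-item dataset with mapping `[0, 1, 1, 1]` yields a 4-item view in which item 1 is oversampled; a `DistributedSampler` built on top of this view then shards and shuffles the repeated indices as usual.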