
Malte lig 4894 fix gatherlayer #1531

Merged · 16 commits merged into master from malte-lig-4894-fix-gatherlayer on May 3, 2024

Conversation

@MalteEbner (Contributor) commented on May 2, 2024

closes #1528

Description

Fixes the GatherLayer by using the implementation from solo-learn.
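
For reference, a minimal sketch of a gather layer in this style. The backward pass matches the stack-and-all_reduce code discussed in the review thread below; the forward pass is assumed to be a plain dist.all_gather, as it is not quoted in this conversation:

    import torch
    import torch.distributed as dist


    class GatherLayer(torch.autograd.Function):
        """Gathers tensors from all processes while supporting backprop (sketch)."""

        @staticmethod
        def forward(ctx, input: torch.Tensor):
            # Assumption: a standard all_gather across all ranks.
            output = [torch.empty_like(input) for _ in range(dist.get_world_size())]
            dist.all_gather(output, input)
            return tuple(output)

        @staticmethod
        def backward(ctx, *grads) -> torch.Tensor:
            # Reduce all incoming gradients across ranks, then return the slice
            # that belongs to the current rank.
            all_gradients = torch.stack(grads)
            dist.all_reduce(all_gradients)
            grad_out = all_gradients[dist.get_rank()]
            return grad_out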

Tests

Adds a test for the GatherLayer by training a model with the NTXentLoss criterion. It compares the training behaviour for the following two cases and ensures that it is exactly the same:

  1. n_devices=1, batch_size=8
  2. n_devices=2, batch_size=4

The test lives in the new file test_dist__gather.py. A separate file is needed because using a DDPStrategy causes the whole file to be executed once per device.
Before the fix, the test failed.

Next tests

This test only covers the NTXentLoss criterion; the other models need to be tested as well.

Testing the full SimCLR model

I also tried to write a similar test for a full SimCLR model. However, it is extremely hard to get exactly the same training behaviour in that setting.

Randomness causes different behaviour between n_devices=2 and n_devices=1

Results:

Using the SimCLR transform leads to different behaviour with n_devices=2 compared to a single device. Even seeding does not help. This is caused by each device processing a different subset (and number) of the samples and thus using different random seeds per sample. E.g.

  • n_devices=1, batch_size=8 processes the samples in order 0, 1, ..., 7 and thus uses seed_0, seed_1, ..., seed_7. Sample_1 therefore uses seed_1.
  • n_devices=2, batch_size=4 processes the samples in order 0, 2, 4, 6 (device 0) and 1, 3, 5, 7 (device 1). Each device uses seed_0, seed_1, ... in parallel, as these are two different processes with their own seeding. Sample_1 therefore gets seed_0, which makes it differ (see the sketch below).
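
For illustration, a minimal sketch (assuming PyTorch's DistributedSampler with shuffle=False; not code from this PR) showing the interleaved per-device sample order:

    import torch
    from torch.utils.data import DistributedSampler, TensorDataset

    # Illustration: 8 samples split across 2 processes. DistributedSampler hands
    # every num_replicas-th index to each rank, so the per-rank sample order
    # differs from a single-device run.
    dataset = TensorDataset(torch.arange(8))
    for rank in (0, 1):
        sampler = DistributedSampler(dataset, num_replicas=2, rank=rank, shuffle=False)
        print(f"rank {rank}: {list(sampler)}")
    # rank 0: [0, 2, 4, 6]
    # rank 1: [1, 3, 5, 7]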

Thus only removing the randomness entirely makes the dataloader output identical for the n_devices=2 and n_devices=1 cases.

The same problem also applies to any randomness in the model itself, e.g. in dropout layers.

Batch normalization causes different behaviour between n_devices=2 and n_devices=1

Batch normalization, or any other operation that uses information from other samples in the same batch, behaves differently with n_devices=2 & batch_size=4 than with n_devices=1 & batch_size=8. Batch normalization would need to be synchronised across devices as well for this to work.
As pointed out by @guarin, we could use SyncBatchNorm to avoid this: https://lightning.ai/docs/pytorch/stable/common/trainer.html#sync-batchnorm
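
A minimal sketch of how this could be enabled via the Trainer flag linked above (the accelerator/devices/strategy values are placeholders):

    import pytorch_lightning as pl

    # With sync_batchnorm=True, Lightning converts the model's BatchNorm layers
    # to torch.nn.SyncBatchNorm so statistics are computed over the global batch
    # instead of each per-device batch.
    trainer = pl.Trainer(
        accelerator="gpu",
        devices=2,
        strategy="ddp",
        sync_batchnorm=True,
    )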

codecov bot commented May 2, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 81.96%. Comparing base (ec9f620) to head (64ff90d).

Additional details and impacted files
@@            Coverage Diff             @@
##           master    #1531      +/-   ##
==========================================
+ Coverage   81.76%   81.96%   +0.20%     
==========================================
  Files         144      144              
  Lines        6092     6094       +2     
==========================================
+ Hits         4981     4995      +14     
+ Misses       1111     1099      -12     


@guarin (Contributor) left a comment

LGTM!

Left some comments regarding typos but otherwise everything looks good :)

Resolved review comments on lightly/utils/dist.py and tests/utils/test_dist__gather.py.
MalteEbner and others added 2 commits May 2, 2024 16:01
Co-authored-by: guarin <43336610+guarin@users.noreply.github.com>
Co-authored-by: guarin <43336610+guarin@users.noreply.github.com>
Diff context from the review thread on lightly/utils/dist.py:

    grad_out = torch.empty_like(input)
    grad_out[:] = grads[dist.get_rank()]
    def backward(ctx, *grads) -> torch.Tensor:
        all_gradients = torch.stack(grads)

@ordabayevy commented

The code here can be optimized by avoiding stacking (see https://github.com/cellarium-ai/cellarium-ml/blob/220ba90b47378c99d4c08b9d91c5c31b796cb3ca/cellarium/ml/distributed/gather.py#L17-L21):

    @staticmethod
    def backward(ctx, *grads) -> torch.Tensor:
        grad_out = grads[dist.get_rank()].contiguous()
        dist.all_reduce(grad_out, op=dist.ReduceOp.SUM)
        return grad_out

@MalteEbner (Contributor, Author) commented May 2, 2024

Thanks for bringing this up!

Interestingly, the code you proposed causes the test to fail. I tested it with n_devices=2 by calling TestGatherLayer().test(None). I would therefore rather keep the existing code, as it is the only version that brings the training to exactly the same values as with n_devices=1.

The code you proposed:

    def backward(ctx, *grads) -> torch.Tensor: 
        print(f"grads: {grads}")
        grad_out = grads[dist.get_rank()].contiguous()
        print(f"grad_out before reduce: {grad_out}")
        dist.all_reduce(grad_out, op=dist.ReduceOp.SUM)
        print(f"grad_out after reduce: {grad_out}")
        return grad_out
grads: (tensor([[-0.0576, -0.0557, -0.0831, -0.0729],
        [-0.0867, -0.0694, -0.0778, -0.0508],
        [-0.0852, -0.0687, -0.0781, -0.0520],
        [ 0.0289, -0.0130, -0.0879, -0.1265]]), tensor([[0.0354, 0.0653, 0.0552, 0.0957],
        [0.0354, 0.0651, 0.0548, 0.0950],
        [0.0353, 0.0649, 0.0546, 0.0946],
        [0.0352, 0.0647, 0.0542, 0.0939]]))
grads: (tensor([[0.0454, 0.0753, 0.0521, 0.0859],
        [0.0457, 0.0759, 0.0524, 0.0865],
        [0.0457, 0.0759, 0.0524, 0.0865],
        [0.0393, 0.0649, 0.0451, 0.0740]]), tensor([[-0.0629, -0.0525, -0.0833, -0.0704],
        [-0.0692, -0.0552, -0.0817, -0.0648],
        [-0.0726, -0.0567, -0.0807, -0.0617],
        [-0.0773, -0.0587, -0.0794, -0.0574]]))

grad_out before reduce: tensor([[-0.0576, -0.0557, -0.0831, -0.0729],
        [-0.0867, -0.0694, -0.0778, -0.0508],
        [-0.0852, -0.0687, -0.0781, -0.0520],
        [ 0.0289, -0.0130, -0.0879, -0.1265]])
grad_out before reduce: tensor([[-0.0629, -0.0525, -0.0833, -0.0704],
        [-0.0692, -0.0552, -0.0817, -0.0648],
        [-0.0726, -0.0567, -0.0807, -0.0617],
        [-0.0773, -0.0587, -0.0794, -0.0574]])
grad_out after reduce: tensor([[-0.1205, -0.1082, -0.1664, -0.1434],
        [-0.1559, -0.1246, -0.1595, -0.1156],
        [-0.1579, -0.1254, -0.1588, -0.1136],
        [-0.0484, -0.0718, -0.1674, -0.1839]])
grad_out after reduce: tensor([[-0.1205, -0.1082, -0.1664, -0.1434],
        [-0.1559, -0.1246, -0.1595, -0.1156],
        [-0.1579, -0.1254, -0.1588, -0.1136],
        [-0.0484, -0.0718, -0.1674, -0.1839]])

The current code:

    def backward(ctx, *grads) -> torch.Tensor: 
        print(f"grads: {grads}")
        all_gradients = torch.stack(grads)
        dist.all_reduce(all_gradients)
        print(f"all_gradients after reduce: {all_gradients}")
        grad_out = all_gradients[dist.get_rank()]
        print(f"grad_out: {grad_out}")
        return grad_out 
grads: (tensor([[-0.0576, -0.0557, -0.0831, -0.0729],
        [-0.0867, -0.0694, -0.0778, -0.0508],
        [-0.0852, -0.0687, -0.0781, -0.0520],
        [ 0.0289, -0.0130, -0.0879, -0.1265]]), tensor([[0.0354, 0.0653, 0.0552, 0.0957],
        [0.0354, 0.0651, 0.0548, 0.0950],
        [0.0353, 0.0649, 0.0546, 0.0946],
        [0.0352, 0.0647, 0.0542, 0.0939]]))
grads: (tensor([[0.0454, 0.0753, 0.0521, 0.0859],
        [0.0457, 0.0759, 0.0524, 0.0865],
        [0.0457, 0.0759, 0.0524, 0.0865],
        [0.0393, 0.0649, 0.0451, 0.0740]]), tensor([[-0.0629, -0.0525, -0.0833, -0.0704],
        [-0.0692, -0.0552, -0.0817, -0.0648],
        [-0.0726, -0.0567, -0.0807, -0.0617],
        [-0.0773, -0.0587, -0.0794, -0.0574]]))
all_gradients after reduce: tensor([[[-0.0121,  0.0196, -0.0311,  0.0130],
         [-0.0410,  0.0065, -0.0254,  0.0357],
         [-0.0395,  0.0072, -0.0257,  0.0346],
         [ 0.0682,  0.0518, -0.0428, -0.0525]],

        [[-0.0275,  0.0129, -0.0281,  0.0253],
         [-0.0339,  0.0099, -0.0268,  0.0303],
         [-0.0373,  0.0083, -0.0261,  0.0329],
         [-0.0421,  0.0059, -0.0252,  0.0365]]])
all_gradients after reduce: tensor([[[-0.0121,  0.0196, -0.0311,  0.0130],
         [-0.0410,  0.0065, -0.0254,  0.0357],
         [-0.0395,  0.0072, -0.0257,  0.0346],
         [ 0.0682,  0.0518, -0.0428, -0.0525]],

        [[-0.0275,  0.0129, -0.0281,  0.0253],
         [-0.0339,  0.0099, -0.0268,  0.0303],
         [-0.0373,  0.0083, -0.0261,  0.0329],
         [-0.0421,  0.0059, -0.0252,  0.0365]]])
grad_out: tensor([[-0.0121,  0.0196, -0.0311,  0.0130],
        [-0.0410,  0.0065, -0.0254,  0.0357],
        [-0.0395,  0.0072, -0.0257,  0.0346],
        [ 0.0682,  0.0518, -0.0428, -0.0525]])
grad_out: tensor([[-0.0275,  0.0129, -0.0281,  0.0253],
        [-0.0339,  0.0099, -0.0268,  0.0303],
        [-0.0373,  0.0083, -0.0261,  0.0329],
        [-0.0421,  0.0059, -0.0252,  0.0365]])

@ordabayevy commented

You are right, thanks! Each gradient in grads needs to be reduced across all ranks.

I initially had the code below, which was working, but then switched to the one above without properly testing it with the model.

    @staticmethod
    def backward(ctx, *grads) -> torch.Tensor:
        new_grads = []
        for grad in grads:
            grad = grad.contiguous()
            dist.all_reduce(grad, op=dist.ReduceOp.SUM)
            new_grads.append(grad)
        grad_out = new_grads[dist.get_rank()]
        return grad_out

@MalteEbner (Contributor, Author) commented May 3, 2024

You are right @ordabayevy, the 2nd method you proposed works.

I think it comes from the fact that all_reduce is a collective operation that every rank must participate in for synchronisation. E.g. rank_0 only needs new_grads[0], but to obtain the synchronised/reduced new_grads[0], all ranks have to take part in reducing it. Thus both rank_0 and rank_1 each need to synchronise/reduce both new_grads[0] and new_grads[1].

    def backward(ctx, *grads) -> torch.Tensor: 
        new_grads = []
        for grad in grads:
            grad = grad.contiguous()
            dist.all_reduce(grad, op=dist.ReduceOp.SUM)
            new_grads.append(grad)
        print(f"rank: {dist.get_rank()}, new_grads: {new_grads}")
        grad_out = new_grads[dist.get_rank()]
        print(f"rank: {dist.get_rank()}, grad_out: {grad_out}")
        return grad_out
rank: 0, new_grads: [tensor([[-0.0121,  0.0196, -0.0311,  0.0130],
        [-0.0410,  0.0065, -0.0254,  0.0357],
        [-0.0395,  0.0072, -0.0257,  0.0346],
        [ 0.0682,  0.0518, -0.0428, -0.0525]]), tensor([[-0.0275,  0.0129, -0.0281,  0.0253],
        [-0.0339,  0.0099, -0.0268,  0.0303],
        [-0.0373,  0.0083, -0.0261,  0.0329],
        [-0.0421,  0.0059, -0.0252,  0.0365]])]
rank: 1, new_grads: [tensor([[-0.0121,  0.0196, -0.0311,  0.0130],
        [-0.0410,  0.0065, -0.0254,  0.0357],
        [-0.0395,  0.0072, -0.0257,  0.0346],
        [ 0.0682,  0.0518, -0.0428, -0.0525]]), tensor([[-0.0275,  0.0129, -0.0281,  0.0253],
        [-0.0339,  0.0099, -0.0268,  0.0303],
        [-0.0373,  0.0083, -0.0261,  0.0329],
        [-0.0421,  0.0059, -0.0252,  0.0365]])]

rank: 1, grad_out: tensor([[-0.0275,  0.0129, -0.0281,  0.0253],
        [-0.0339,  0.0099, -0.0268,  0.0303],
        [-0.0373,  0.0083, -0.0261,  0.0329],
        [-0.0421,  0.0059, -0.0252,  0.0365]])
rank: 0, grad_out: tensor([[-0.0121,  0.0196, -0.0311,  0.0130],
        [-0.0410,  0.0065, -0.0254,  0.0357],
        [-0.0395,  0.0072, -0.0257,  0.0346],
        [ 0.0682,  0.0518, -0.0428, -0.0525]])

@MalteEbner MalteEbner enabled auto-merge (squash) May 3, 2024 19:22
@MalteEbner MalteEbner merged commit 80fa5b5 into master May 3, 2024
10 checks passed
@MalteEbner MalteEbner deleted the malte-lig-4894-fix-gatherlayer branch May 3, 2024 19:25
Successfully merging this pull request may close these issues.

Bug in GatherLayer.backward
3 participants