Malte lig 4894 fix gatherlayer #1531
Conversation
Codecov Report: All modified and coverable lines are covered by tests ✅
@@            Coverage Diff             @@
##           master    #1531      +/-   ##
==========================================
+ Coverage   81.76%   81.96%   +0.20%
==========================================
  Files         144      144
  Lines        6092     6094       +2
==========================================
+ Hits         4981     4995      +14
+ Misses       1111     1099      -12
LGTM!
Left some comments regarding typos, but otherwise everything looks good :)
Co-authored-by: guarin <43336610+guarin@users.noreply.github.com>
grad_out = torch.empty_like(input)
grad_out[:] = grads[dist.get_rank()]
def backward(ctx, *grads) -> torch.Tensor:
    all_gradients = torch.stack(grads)
The code here can be optimized by avoiding stacking (see https://github.com/cellarium-ai/cellarium-ml/blob/220ba90b47378c99d4c08b9d91c5c31b796cb3ca/cellarium/ml/distributed/gather.py#L17-L21):
@staticmethod
def backward(ctx, *grads) -> torch.Tensor:
    grad_out = grads[dist.get_rank()].contiguous()
    dist.all_reduce(grad_out, op=dist.ReduceOp.SUM)
    return grad_out
Thanks for bringing this up!
Interestingly, the code you proposed causes the test to fail. I tested it using n_devices=2 and by calling TestGatherLayer().test(None). Thus I would rather keep the existing code, as it is the only one that makes the training arrive at exactly the same values as with n_devices=1.
The code you proposed:
def backward(ctx, *grads) -> torch.Tensor:
    print(f"grads: {grads}")
    grad_out = grads[dist.get_rank()].contiguous()
    print(f"grad_out before reduce: {grad_out}")
    dist.all_reduce(grad_out, op=dist.ReduceOp.SUM)
    print(f"grad_out after reduce: {grad_out}")
    return grad_out
grads: (tensor([[-0.0576, -0.0557, -0.0831, -0.0729],
[-0.0867, -0.0694, -0.0778, -0.0508],
[-0.0852, -0.0687, -0.0781, -0.0520],
[ 0.0289, -0.0130, -0.0879, -0.1265]]), tensor([[0.0354, 0.0653, 0.0552, 0.0957],
[0.0354, 0.0651, 0.0548, 0.0950],
[0.0353, 0.0649, 0.0546, 0.0946],
[0.0352, 0.0647, 0.0542, 0.0939]]))
grads: (tensor([[0.0454, 0.0753, 0.0521, 0.0859],
[0.0457, 0.0759, 0.0524, 0.0865],
[0.0457, 0.0759, 0.0524, 0.0865],
[0.0393, 0.0649, 0.0451, 0.0740]]), tensor([[-0.0629, -0.0525, -0.0833, -0.0704],
[-0.0692, -0.0552, -0.0817, -0.0648],
[-0.0726, -0.0567, -0.0807, -0.0617],
[-0.0773, -0.0587, -0.0794, -0.0574]]))
grad_out before reduce: tensor([[-0.0576, -0.0557, -0.0831, -0.0729],
[-0.0867, -0.0694, -0.0778, -0.0508],
[-0.0852, -0.0687, -0.0781, -0.0520],
[ 0.0289, -0.0130, -0.0879, -0.1265]])
grad_out before reduce: tensor([[-0.0629, -0.0525, -0.0833, -0.0704],
[-0.0692, -0.0552, -0.0817, -0.0648],
[-0.0726, -0.0567, -0.0807, -0.0617],
[-0.0773, -0.0587, -0.0794, -0.0574]])
grad_out after reduce: tensor([[-0.1205, -0.1082, -0.1664, -0.1434],
[-0.1559, -0.1246, -0.1595, -0.1156],
[-0.1579, -0.1254, -0.1588, -0.1136],
[-0.0484, -0.0718, -0.1674, -0.1839]])
grad_out after reduce: tensor([[-0.1205, -0.1082, -0.1664, -0.1434],
[-0.1559, -0.1246, -0.1595, -0.1156],
[-0.1579, -0.1254, -0.1588, -0.1136],
[-0.0484, -0.0718, -0.1674, -0.1839]])
The current code:
def backward(ctx, *grads) -> torch.Tensor:
    print(f"grads: {grads}")
    all_gradients = torch.stack(grads)
    dist.all_reduce(all_gradients)
    print(f"all_gradients after reduce: {all_gradients}")
    grad_out = all_gradients[dist.get_rank()]
    print(f"grad_out: {grad_out}")
    return grad_out
grads: (tensor([[-0.0576, -0.0557, -0.0831, -0.0729],
[-0.0867, -0.0694, -0.0778, -0.0508],
[-0.0852, -0.0687, -0.0781, -0.0520],
[ 0.0289, -0.0130, -0.0879, -0.1265]]), tensor([[0.0354, 0.0653, 0.0552, 0.0957],
[0.0354, 0.0651, 0.0548, 0.0950],
[0.0353, 0.0649, 0.0546, 0.0946],
[0.0352, 0.0647, 0.0542, 0.0939]]))
grads: (tensor([[0.0454, 0.0753, 0.0521, 0.0859],
[0.0457, 0.0759, 0.0524, 0.0865],
[0.0457, 0.0759, 0.0524, 0.0865],
[0.0393, 0.0649, 0.0451, 0.0740]]), tensor([[-0.0629, -0.0525, -0.0833, -0.0704],
[-0.0692, -0.0552, -0.0817, -0.0648],
[-0.0726, -0.0567, -0.0807, -0.0617],
[-0.0773, -0.0587, -0.0794, -0.0574]]))
all_gradients after reduce: tensor([[[-0.0121, 0.0196, -0.0311, 0.0130],
[-0.0410, 0.0065, -0.0254, 0.0357],
[-0.0395, 0.0072, -0.0257, 0.0346],
[ 0.0682, 0.0518, -0.0428, -0.0525]],
[[-0.0275, 0.0129, -0.0281, 0.0253],
[-0.0339, 0.0099, -0.0268, 0.0303],
[-0.0373, 0.0083, -0.0261, 0.0329],
[-0.0421, 0.0059, -0.0252, 0.0365]]])
all_gradients after reduce: tensor([[[-0.0121, 0.0196, -0.0311, 0.0130],
[-0.0410, 0.0065, -0.0254, 0.0357],
[-0.0395, 0.0072, -0.0257, 0.0346],
[ 0.0682, 0.0518, -0.0428, -0.0525]],
[[-0.0275, 0.0129, -0.0281, 0.0253],
[-0.0339, 0.0099, -0.0268, 0.0303],
[-0.0373, 0.0083, -0.0261, 0.0329],
[-0.0421, 0.0059, -0.0252, 0.0365]]])
grad_out: tensor([[-0.0121, 0.0196, -0.0311, 0.0130],
[-0.0410, 0.0065, -0.0254, 0.0357],
[-0.0395, 0.0072, -0.0257, 0.0346],
[ 0.0682, 0.0518, -0.0428, -0.0525]])
grad_out: tensor([[-0.0275, 0.0129, -0.0281, 0.0253],
[-0.0339, 0.0099, -0.0268, 0.0303],
[-0.0373, 0.0083, -0.0261, 0.0329],
[-0.0421, 0.0059, -0.0252, 0.0365]])
You are right, thanks! Each gradient in grads needs to be reduced across all ranks.
I initially had the code below, which worked, but then switched to the one above without properly testing it with the model.
@staticmethod
def backward(ctx, *grads) -> torch.Tensor:
    new_grads = []
    for grad in grads:
        grad = grad.contiguous()
        dist.all_reduce(grad, op=dist.ReduceOp.SUM)
        new_grads.append(grad)
    grad_out = new_grads[dist.get_rank()]
    return grad_out
You are right @ordabayevy, the second method you proposed works.
I think this comes from the fact that every rank must participate in an all_reduce for it to synchronise. E.g. rank_0 only needs new_grads[0]; however, to obtain the synchronised/reduced new_grads[0], all ranks have to take part in reducing it. Thus both rank_0 and rank_1 each need to synchronise/reduce both new_grads[0] and new_grads[1]. (A single-process sketch of this arithmetic follows after the logs below.)
def backward(ctx, *grads) -> torch.Tensor:
    new_grads = []
    for grad in grads:
        grad = grad.contiguous()
        dist.all_reduce(grad, op=dist.ReduceOp.SUM)
        new_grads.append(grad)
    print(f"rank: {dist.get_rank()}, new_grads: {new_grads}")
    grad_out = new_grads[dist.get_rank()]
    print(f"rank: {dist.get_rank()}, grad_out: {grad_out}")
    return grad_out
rank: 0, new_grads: [tensor([[-0.0121, 0.0196, -0.0311, 0.0130],
[-0.0410, 0.0065, -0.0254, 0.0357],
[-0.0395, 0.0072, -0.0257, 0.0346],
[ 0.0682, 0.0518, -0.0428, -0.0525]]), tensor([[-0.0275, 0.0129, -0.0281, 0.0253],
[-0.0339, 0.0099, -0.0268, 0.0303],
[-0.0373, 0.0083, -0.0261, 0.0329],
[-0.0421, 0.0059, -0.0252, 0.0365]])]
rank: 1, new_grads: [tensor([[-0.0121, 0.0196, -0.0311, 0.0130],
[-0.0410, 0.0065, -0.0254, 0.0357],
[-0.0395, 0.0072, -0.0257, 0.0346],
[ 0.0682, 0.0518, -0.0428, -0.0525]]), tensor([[-0.0275, 0.0129, -0.0281, 0.0253],
[-0.0339, 0.0099, -0.0268, 0.0303],
[-0.0373, 0.0083, -0.0261, 0.0329],
[-0.0421, 0.0059, -0.0252, 0.0365]])]
rank: 1, grad_out: tensor([[-0.0275, 0.0129, -0.0281, 0.0253],
[-0.0339, 0.0099, -0.0268, 0.0303],
[-0.0373, 0.0083, -0.0261, 0.0329],
[-0.0421, 0.0059, -0.0252, 0.0365]])
rank: 0, grad_out: tensor([[-0.0121, 0.0196, -0.0311, 0.0130],
[-0.0410, 0.0065, -0.0254, 0.0357],
[-0.0395, 0.0072, -0.0257, 0.0346],
[ 0.0682, 0.0518, -0.0428, -0.0525]])
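To make the failure mode of the first suggestion concrete, here is a single-process sketch of the arithmetic (no process group needed; the tensors are made-up stand-ins for the grads tuple each rank sees, not values from the logs above):

import torch

# grads_on_rank[k] plays the role of the grads tuple that rank k sees in
# backward() of GatherLayer.
grads_on_rank = [
    (torch.tensor([1.0]), torch.tensor([2.0])),    # grads as seen by rank 0
    (torch.tensor([10.0]), torch.tensor([20.0])),  # grads as seen by rank 1
]

# Correct (both working variants above): every slice r is reduced across all
# ranks, and rank r then keeps slice r.
correct = [sum(g[r] for g in grads_on_rank) for r in range(2)]
# -> rank 0 keeps 11.0, rank 1 keeps 22.0

# Broken first suggestion: each rank contributes only its *own* slice to the
# all_reduce, so both ranks end up with the same mixed value.
broken = sum(grads_on_rank[r][r] for r in range(2))
# -> both ranks get 1.0 + 20.0 = 21.0, matching neither correct value
print(correct, broken)

Because the broken variant sums different slices from different ranks, every rank receives the same mixed gradient instead of its own reduced slice.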
closes #1528
Description
Fixes the GatherLayer by using the implementation from solo-learn; a sketch of the resulting layer is below.
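For reference, a sketch of the fixed layer. The backward pass matches the code discussed in the review thread above; the forward pass using dist.all_gather is an assumption based on the solo-learn implementation, not code quoted in this PR:

import torch
import torch.distributed as dist

class GatherLayer(torch.autograd.Function):
    """Gather tensors from all processes while supporting backpropagation."""

    @staticmethod
    def forward(ctx, input: torch.Tensor) -> tuple:
        # Assumed forward: collect the input tensor from every rank.
        output = [torch.empty_like(input) for _ in range(dist.get_world_size())]
        dist.all_gather(output, input)
        return tuple(output)

    @staticmethod
    def backward(ctx, *grads) -> torch.Tensor:
        # Reduce every per-rank gradient slice across all ranks, then keep
        # the slice belonging to this rank (the backward adopted in this PR).
        all_gradients = torch.stack(grads)
        dist.all_reduce(all_gradients)
        return all_gradients[dist.get_rank()]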
Tests
Adds a test for the GatherLayer by testing a model that uses the NTXentLoss criterion. It compares the training behaviour for the two cases n_devices=1 and n_devices=2 (with the per-device batch size halved) and ensures that it is exactly the same.
The test is in the new file test_dist__gather.py. A separate file is needed because using a DDPStrategy causes the file to be executed once per device. Before the fix, the test failed.
Next tests
This test only covers the NTXentLoss criterion; the other models need to be tested as well.
Testing the full SimCLR model
I also tried to write a similar test using a full SimCLR model. However, it is extremely hard to get exactly the same training behaviour in that setting.
Randomness causes different behaviour between n_devices=2 and n_devices=1
Results:
Using the SimCLR transform leads to different behaviour with n_devices=2 compared to only 1 device. Even seeding does not help: each device sees a different number of samples, so the random-number streams diverge between the two setups.
Thus only removing randomness makes the output of the dataloader the same for the n_devices=2 and n_devices=1 cases.
The same problem also applies to any randomness in the model itself, e.g. in dropout layers.
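As an illustration, a minimal sketch of a fully deterministic input pipeline (the tensor shapes and the idea of feeding two identical views are illustrative assumptions, not the actual test code):

import torch
from torch.utils.data import DataLoader, TensorDataset

# Fixed tensors instead of random augmentations, and no shuffling, so that
# n_devices=1 and n_devices=2 iterate over exactly the same samples in the
# same order.
torch.manual_seed(0)
images = torch.randn(8, 3, 32, 32)               # generated once, up front
dataset = TensorDataset(images, images.clone())  # two identical "views"
dataloader = DataLoader(dataset, batch_size=4, shuffle=False, num_workers=0)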
Batch normalization causes different behaviour between n_devices=2 and n_devices=1
Batch normalization, or any other operation that uses information from other samples in the same batch, behaves differently with n_devices=2 & batch_size=4 compared to n_devices=1 & batch_size=8. The batch normalization would need to be synchronised as well for this to work. As pointed out by Guarin, we could use SyncBatchNorm to avoid this: https://lightning.ai/docs/pytorch/stable/common/trainer.html#sync-batchnorm
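Per the linked Lightning docs, this is a single Trainer flag; a minimal sketch (the devices/strategy arguments are placeholders):

from pytorch_lightning import Trainer

# Convert all BatchNorm layers to their synchronized counterparts for DDP.
trainer = Trainer(devices=2, strategy="ddp", sync_batchnorm=True)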