
Gradient caching vs Model dropout #12

Open
harveyp123 opened this issue Feb 5, 2024 · 3 comments

harveyp123 commented Feb 5, 2024

The GC-DPR implementation has two steps:

  1. The first step does a full-batch forward pass without gradient tracking, computes the full-batch contrastive loss, and obtains the gradient of that loss with respect to each embedding.
  2. The second step runs mini-batch forward passes, assigns the cached embedding gradients, and backpropagates. The mini-batches loop through the full batch so that all parameter gradients are computed and accumulated.

However, during this computation there might be one issue:

  1. The backbone model has a randomized dropout process, which makes steps 1 and 2 inconsistent: the dropout masks sampled in step 1 differ from those sampled in step 2, so step 1's embedding gradients cannot be applied directly to step 2. The embedding gradients would have to be recomputed for every mini-batch, or some more sophisticated handling is needed to keep the two passes consistent (see the sketch below).
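
To make the concern concrete, here is a minimal sketch of the two-pass pattern, assuming hypothetical encoder, contrastive_loss, and chunks names (this is not the actual GC-DPR code). Without controlling the RNG, the dropout mask sampled in pass 2 differs from the one sampled in pass 1:

import torch

# Pass 1: full-batch forward without parameter gradients, to obtain the
# gradient of the contrastive loss with respect to each embedding.
embeddings = []
for chunk in chunks:                          # chunks partition the large batch
    with torch.no_grad():
        embeddings.append(encoder(**chunk))   # dropout mask A is sampled here
reps = torch.cat(embeddings).requires_grad_()
loss = contrastive_loss(reps)                 # full-batch contrastive loss
loss.backward()
cached_grads = reps.grad.split([len(c["input_ids"]) for c in chunks])

# Pass 2: mini-batch forward with gradients, applying the cached embedding
# gradients to accumulate parameter gradients.
for chunk, grad in zip(chunks, cached_grads):
    rep = encoder(**chunk)                    # dropout mask B is sampled here;
                                              # without RNG control, B != A
    rep.backward(gradient=grad)               # accumulates into encoder params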
harveyp123 (Author) commented

In short, in the second loop, for each mini-batch of queries and passages you could put the freshly recomputed query and passage embeddings back into the original full batch, recompute the loss, and calculate the gradient for the current queries/passages only, so that the change in dropout behavior does not distort the gradient (roughly as sketched below).
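
A rough sketch of that idea, reusing the hypothetical names from the earlier snippet (this is a suggestion, not how GC-DPR actually works):

import torch

# Splice the freshly recomputed embeddings of the current mini-batch into the
# detached full-batch embeddings from pass 1, so the loss is computed under
# this mini-batch's own dropout masks.
cached = torch.cat(embeddings)                    # detached pass-1 embeddings
sizes = [len(c["input_ids"]) for c in chunks]
starts = [sum(sizes[:i]) for i in range(len(sizes))]

for i, chunk in enumerate(chunks):
    rep = encoder(**chunk)                        # fresh forward, fresh dropout
    full = torch.cat([cached[:starts[i]], rep, cached[starts[i] + sizes[i]:]])
    loss = contrastive_loss(full)
    loss.backward()                               # gradient flows only through rep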

luyug (Owner) commented Feb 23, 2024

In our training code, the random states are snapshotted using the RandContext class

import torch
from torch.utils.checkpoint import get_device_states, set_device_states


class RandContext:
    """Snapshot of the CPU and relevant GPU RNG states, restorable as a context manager."""

    def __init__(self, *tensors):
        # Capture the CPU RNG state and the RNG state of every device holding one of `tensors`.
        self.fwd_cpu_state = torch.get_rng_state()
        self.fwd_gpu_devices, self.fwd_gpu_states = get_device_states(*tensors)

    def __enter__(self):
        # Fork the global RNG so restoring the saved states does not leak outside this context.
        self._fork = torch.random.fork_rng(
            devices=self.fwd_gpu_devices,
            enabled=True
        )
        self._fork.__enter__()
        torch.set_rng_state(self.fwd_cpu_state)
        set_device_states(self.fwd_gpu_devices, self.fwd_gpu_states)

    def __exit__(self, exc_type, exc_val, exc_tb):
        self._fork.__exit__(exc_type, exc_val, exc_tb)
        self._fork = None
in the first forward pass and restored at the beginning of the second, so what you described shouldn't be a problem.
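
For reference, a hedged sketch of how such RNG snapshotting can be wired into the two passes (variable names such as encoder, chunks, and cached_grads are illustrative, not the exact training-loop code):

import torch

# Pass 1: snapshot the RNG state right before each no-grad forward.
rnd_states, embeddings = [], []
for chunk in chunks:
    rnd_states.append(RandContext(*chunk.values()))   # capture CPU/GPU RNG state
    with torch.no_grad():
        embeddings.append(encoder(**chunk))

# ... compute the full-batch loss and cache the per-embedding gradients ...

# Pass 2: re-enter each saved RNG state so the dropout masks (and any other
# randomness) match the first pass exactly, then apply the cached gradient.
for chunk, state, grad in zip(chunks, rnd_states, cached_grads):
    with state:                                       # restore the saved RNG state
        rep = encoder(**chunk)
    rep.backward(gradient=grad)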

harveyp123 (Author) commented Feb 23, 2024

Oh, okay. I was using DeepSpeed + gradient caching; the model is wrapped in a DeepSpeed-defined object, and RandContext doesn't work on my side. But it's good to learn from your code :)
