Shouldn't accumulate_gradient pass rng_key? #286

Open
hrbigelow opened this issue Sep 8, 2023 · 0 comments
hrbigelow commented Sep 8, 2023

Hi,

I was looking at this code for accumulate_gradient. I usually pass a params['rng_key'] as part of the state, but as written it would not feed a different rng_key to the loss for each accumulation step.

For example, I was thinking something like this should be done instead:

    def acc_grad_and_loss(i, carry):
      l, g, rng_key = carry
      imgs = jax.lax.dynamic_slice(images, (i * step_size, 0, 0, 0),
                                   (step_size,) + images.shape[1:])
      lbls = jax.lax.dynamic_slice(labels, (i * step_size, 0),
                                   (step_size, labels.shape[1]))
      # If the loss has stochasticity, it should see a different random key
      # on each accumulation step, so split off a fresh subkey here.
      rng_key, step_key = jax.random.split(rng_key)
      params['rng_key'] = step_key
      li, gi = loss_and_grad_fn(params, imgs, lbls)
      return (l + li, jax.tree_map(lambda x, y: x + y, g, gi), rng_key)

    l, g, rng_key = jax.lax.fori_loop(1, accum_steps, acc_grad_and_loss,
                                      (l, g, rng_key))
    return jax.tree_map(lambda x: x / accum_steps, (l, g))
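
For reference, here is a rough self-contained sketch of the same idea (the toy loss_fn, the array shapes, and the explicit rng argument to the loss are placeholders for illustration, not the repo's actual code). It derives a distinct key per accumulation step with jax.random.fold_in on the step index, which avoids carrying the key in the loop state at all:

    # Hypothetical standalone sketch, not the repository's accumulate_gradient.
    import jax
    import jax.numpy as jnp

    def loss_fn(params, rng, imgs, lbls):
      # Toy stochastic loss: dropout-like masking on a linear model.
      logits = imgs.reshape(imgs.shape[0], -1) @ params['w']
      keep = jax.random.bernoulli(rng, 0.9, logits.shape)
      logits = jnp.where(keep, logits, 0.0) / 0.9
      return jnp.mean((logits - lbls) ** 2)

    def accumulate_gradient(params, rng_key, images, labels, accum_steps):
      """Averages loss and grads over accum_steps micro-batches."""
      step_size = images.shape[0] // accum_steps
      grad_fn = jax.value_and_grad(loss_fn)

      def one_step(i, l_and_g):
        l, g = l_and_g
        imgs = jax.lax.dynamic_slice(images, (i * step_size, 0, 0, 0),
                                     (step_size,) + images.shape[1:])
        lbls = jax.lax.dynamic_slice(labels, (i * step_size, 0),
                                     (step_size, labels.shape[1]))
        # Distinct key per micro-batch, derived from the step index.
        step_rng = jax.random.fold_in(rng_key, i)
        li, gi = grad_fn(params, step_rng, imgs, lbls)
        return l + li, jax.tree_util.tree_map(lambda x, y: x + y, g, gi)

      zeros = jax.tree_util.tree_map(jnp.zeros_like, params)
      l, g = one_step(0, (jnp.zeros(()), zeros))
      l, g = jax.lax.fori_loop(1, accum_steps, one_step, (l, g))
      return jax.tree_util.tree_map(lambda x: x / accum_steps, (l, g))

    key = jax.random.PRNGKey(0)
    params = {'w': jnp.zeros((28 * 28, 10))}
    images = jnp.ones((32, 28, 28, 1))
    labels = jnp.ones((32, 10))
    loss, grads = accumulate_gradient(params, key, images, labels, accum_steps=4)
    print(float(loss), grads['w'].shape)

With fold_in, each step's key is a pure function of (rng_key, i), so the same scheme would also work unchanged if the loop were later switched to lax.scan or unrolled.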