I was looking at this code for `accumulate_gradient`. I usually pass a `params['rng_key']` as part of the state, but as written, this code would not feed a different `rng_key` to each accumulation step.

For example, I was thinking something like this should be done instead:
```python
def acc_grad_and_loss(i, l_and_g_and_key):
    # The rng_key must be part of the fori_loop carry, since the body
    # only receives (i, carry).
    l, g, rng_key = l_and_g_and_key
    imgs = jax.lax.dynamic_slice(images, (i * step_size, 0, 0, 0),
                                 (step_size,) + images.shape[1:])
    lbls = jax.lax.dynamic_slice(labels, (i * step_size, 0),
                                 (step_size, labels.shape[1]))
    # If the loss has stochasticity, it should see a different random
    # seed on each accumulation step.
    params['rng_key'] = rng_key
    rng_key, = jax.random.split(rng_key, 1)
    li, gi = loss_and_grad_fn(params, imgs, lbls)
    return (l + li, jax.tree_map(lambda x, y: x + y, g, gi), rng_key)

l, g, rng_key = jax.lax.fori_loop(1, accum_steps, acc_grad_and_loss,
                                  (l, g, rng_key))
return jax.tree_map(lambda x: x / accum_steps, (l, g))
```
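The split-per-step pattern can be checked in isolation. Below is a minimal, self-contained sketch (the name `accumulate_keys` is mine, not from the repo) that records the key each accumulation step would see and confirms they are all distinct:

```python
import jax
import jax.numpy as jnp

def accumulate_keys(rng_key, accum_steps):
    """Collect the per-step key each accumulation step would use."""
    def body(i, carry):
        rng_key, keys = carry
        step_key = rng_key                       # key used by this step's loss
        rng_key, = jax.random.split(rng_key, 1)  # fresh key for the next step
        keys = keys.at[i].set(step_key)
        return (rng_key, keys)

    keys = jnp.zeros((accum_steps, 2), dtype=jnp.uint32)
    _, keys = jax.lax.fori_loop(0, accum_steps, body, (rng_key, keys))
    return keys

keys = accumulate_keys(jax.random.PRNGKey(0), 4)
# All four step keys are distinct, so a stochastic loss (e.g. one using
# dropout) draws independent randomness on each accumulation step.
```

Without the `split`, every step would reuse the same `params['rng_key']`, so a stochastic loss would apply identical noise to each micro-batch.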