From 8c0a60ae904588182ba77dea1c72018ebfbcfc8b Mon Sep 17 00:00:00 2001
From: Ivy Zheng
Date: Thu, 22 Sep 2022 14:17:01 -0700
Subject: [PATCH] Add perturb() to module.py

---
 flax/linen/module.py             | 44 ++++++++++++++++++++++++++++++++
 tests/linen/linen_module_test.py | 23 +++++++++++++++++
 2 files changed, 67 insertions(+)

diff --git a/flax/linen/module.py b/flax/linen/module.py
index c0619a46f9..65c529bdda 100644
--- a/flax/linen/module.py
+++ b/flax/linen/module.py
@@ -1462,6 +1462,50 @@ def __call__(self, x):
     self.scope.put_variable(col, name, xs)
     return True
 
+  def perturb(self, name: str, value: T, collection: str = 'perturbations') -> T:
+    """Add a zero-valued variable ('perturbation') to the intermediate value.
+
+    The gradient of `value` will be the same as the gradient of this
+    perturbation variable. Therefore, if you define your loss function with
+    both params and perturbations as standalone arguments, you can get the
+    intermediate gradients of `value` by running `jax.grad` on the perturbation
+    argument.
+
+    Note: this is an experimental API and may be tweaked later for better
+    performance and usability.
+    At its current stage, it creates extra dummy variables that occupy extra
+    memory space. Use it only to debug gradients in training.
+
+    Example::
+
+      import jax
+      import jax.numpy as jnp
+      import flax.linen as nn
+
+      class Foo(nn.Module):
+        @nn.compact
+        def __call__(self, x):
+          x = nn.Dense(3)(x)
+          x = self.perturb('dense3', x)
+          return nn.Dense(2)(x)
+
+      def loss(params, perturbations, inputs, targets):
+        variables = {'params': params, 'perturbations': perturbations}
+        preds = model.apply(variables, inputs)
+        return jnp.square(preds - targets).mean()
+
+      x = jnp.ones((2, 9))
+      y = jnp.ones((2, 2))
+      model = Foo()
+      variables = model.init(jax.random.PRNGKey(0), x)
+      intm_grads = jax.grad(loss, argnums=1)(variables['params'], variables['perturbations'], x, y)
+      print(intm_grads['dense3'])  # ==> [[-1.456924   -0.44332537  0.02422847]
+                                   #      [-1.456924   -0.44332537  0.02422847]]
+
+    """
+    value += self.variable(collection, name, lambda: jnp.zeros_like(value)).value
+    return value
+
   def tabulate(
       self,
       rngs: Union[PRNGKey, RNGSequences],
diff --git a/tests/linen/linen_module_test.py b/tests/linen/linen_module_test.py
index ee99762dd2..9e161f31a4 100644
--- a/tests/linen/linen_module_test.py
+++ b/tests/linen/linen_module_test.py
@@ -1417,6 +1417,29 @@ def __call__(self, x):
     _, state = Foo().apply({}, 1, capture_intermediates=fn)
     self.assertEqual(state, {'intermediates': {'Bar_0': {'test': (2,)}}})
 
+  def test_perturb(self):
+    class Foo(nn.Module):
+      @nn.compact
+      def __call__(self, x):
+        x = nn.Dense(10)(x)
+        x = self.perturb('before_multiply', x)
+        x = 4 * x
+        x = self.perturb('after_multiply', x)
+        return x
+
+    def loss(params, perturbations, inputs, targets):
+      variables = {'params': params, 'perturbations': perturbations}
+      preds = Foo().apply(variables, inputs)
+      return jnp.square(preds - targets).mean()
+
+    x = jax.random.uniform(jax.random.PRNGKey(1), shape=(10,))
+    y = jax.random.uniform(jax.random.PRNGKey(2), shape=(10,))
+    variables = Foo().init(jax.random.PRNGKey(0), x)
+    pred = Foo().apply(variables, x)
+    intm_grads = jax.grad(loss, argnums=1)(variables['params'], variables['perturbations'], x, y)
+    # The activation is multiplied by 4, so the reverse-mode gradient is also scaled by 4.
+    self.assertTrue(all(intm_grads['after_multiply'] * 4 == intm_grads['before_multiply']))
+
   def test_functional_apply(self):
 
     class Foo(nn.Module):
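
Why differentiating with respect to the perturbation collection recovers intermediate gradients: adding a zero-valued array to an intermediate is an identity in the forward pass, but `jax.grad` with respect to that array returns exactly the gradient flowing through the intermediate. A minimal plain-JAX sketch of this trick (illustrative only; the `forward` function, names, and shapes below are made up for this note and are not part of the patch)::

    import jax
    import jax.numpy as jnp

    def forward(w, perturbation, x):
      h = x @ w             # intermediate value whose gradient we want
      h = h + perturbation  # zero-valued dummy term, analogous to self.perturb()
      return jnp.sum(h ** 2)

    w = jnp.ones((3, 2))
    x = jnp.ones((4, 3))
    zeros = jnp.zeros((4, 2))  # same shape as the intermediate h

    # Since h + 0 is an identity in the forward pass, the gradient with respect
    # to the zero-valued argument equals the gradient with respect to h itself.
    intm_grad = jax.grad(forward, argnums=1)(w, zeros, x)
    print(intm_grad.shape)  # (4, 2)

This is what `perturb()` automates: it registers the zeros array via `self.variable(collection, name, lambda: jnp.zeros_like(value))`, so the dummy variables are created at `init` time and exposed under the 'perturbations' collection.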