
Commit

Add perturb() to module.py
IvyZX committed Sep 23, 2022
1 parent 45c2955 commit 8c0a60a
Showing 2 changed files with 67 additions and 0 deletions.
44 changes: 44 additions & 0 deletions flax/linen/module.py
@@ -1462,6 +1462,50 @@ def __call__(self, x):
      self.scope.put_variable(col, name, xs)
    return True

  def perturb(self, name: str, value: T, collection: str = 'perturbations') -> T:
    """Add a zero-valued variable ('perturbation') to the intermediate value.

    The gradient of `value` would be the same as the gradient of this
    perturbation variable. Therefore, if you define your loss function with
    both params and perturbations as standalone arguments, you can get the
    intermediate gradients of `value` by running `jax.grad` on the perturbation
    argument.

    Note: this is an experimental API and may be tweaked later for better
    performance and usability.
    At its current stage, it creates extra dummy variables that occupy extra
    memory space. Use it only to debug gradients in training.

    Example::

      import jax
      import jax.numpy as jnp
      import flax.linen as nn

      class Foo(nn.Module):
        @nn.compact
        def __call__(self, x):
          x = nn.Dense(3)(x)
          x = self.perturb('dense3', x)
          return nn.Dense(2)(x)

      def loss(params, perturbations, inputs, targets):
        variables = {'params': params, 'perturbations': perturbations}
        preds = model.apply(variables, inputs)
        return jnp.square(preds - targets).mean()

      x = jnp.ones((2, 9))
      y = jnp.ones((2, 2))
      model = Foo()
      variables = model.init(jax.random.PRNGKey(0), x)
      intm_grads = jax.grad(loss, argnums=1)(variables['params'], variables['perturbations'], x, y)
      print(intm_grads['dense3'])  # ==> [[-1.456924   -0.44332537  0.02422847]
                                   #      [-1.456924   -0.44332537  0.02422847]]
    """
    value += self.variable(collection, name, lambda: jnp.zeros_like(value)).value
    return value

  def tabulate(
      self,
      rngs: Union[PRNGKey, RNGSequences],
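The docstring above states that the gradient of `value` equals the gradient of the zero-valued perturbation variable. A minimal pure-JAX sketch of that identity, independent of Flax (the names `h`, `g`, and `loss_with_perturbation` are illustrative only, not part of this commit):

    import jax
    import jax.numpy as jnp

    def h(x):
      return x ** 2           # some intermediate computation

    def g(y):
      return jnp.sum(y ** 3)  # the rest of the network / the loss

    def loss_with_perturbation(p, x):
      # Adding p leaves the forward value unchanged at p = 0, but
      # d/dp g(h(x) + p) at p = 0 is exactly dg/dh evaluated at h(x).
      return g(h(x) + p)

    x = jnp.arange(3.0)
    zero_p = jnp.zeros_like(h(x))
    intermediate_grad = jax.grad(loss_with_perturbation)(zero_p, x)
    # Same quantity via the chain rule by hand: dg/dh = 3 * h(x) ** 2
    print(jnp.allclose(intermediate_grad, 3 * h(x) ** 2))  # True

This is the same mechanism `perturb()` packages up: the variable it adds is always zero in the forward pass, so only the backward pass is affected.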
23 changes: 23 additions & 0 deletions tests/linen/linen_module_test.py
@@ -1417,6 +1417,29 @@ def __call__(self, x):
    _, state = Foo().apply({}, 1, capture_intermediates=fn)
    self.assertEqual(state, {'intermediates': {'Bar_0': {'test': (2,)}}})

  def test_perturb(self):
    class Foo(nn.Module):
      @nn.compact
      def __call__(self, x):
        x = nn.Dense(10)(x)
        x = self.perturb('before_multiply', x)
        x = 4 * x
        x = self.perturb('after_multiply', x)
        return x

    def loss(params, perturbations, inputs, targets):
      variables = {'params': params, 'perturbations': perturbations}
      preds = Foo().apply(variables, inputs)
      return jnp.square(preds - targets).mean()

    x = jax.random.uniform(jax.random.PRNGKey(1), shape=(10, ))
    y = jax.random.uniform(jax.random.PRNGKey(2), shape=(10, ))
    variables = Foo().init(jax.random.PRNGKey(0), x)
    pred = Foo().apply(variables, x)
    intm_grads = jax.grad(loss, argnums=1)(variables['params'], variables['perturbations'], x, y)
    # The forward pass multiplies the activation by 4, so by the chain rule the
    # gradient at 'before_multiply' is 4x the gradient at 'after_multiply'.
    self.assertTrue(all(intm_grads['after_multiply'] * 4 == intm_grads['before_multiply']))

  def test_functional_apply(self):

    class Foo(nn.Module):
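As a usage note (not part of the commit), the loss in the test above can also yield parameter gradients and intermediate gradients in a single call. This sketch assumes the `Foo`, `loss`, `variables`, `x`, and `y` defined in `test_perturb`:

    # argnums=(0, 1) differentiates with respect to both positional arguments.
    param_grads, intm_grads = jax.grad(loss, argnums=(0, 1))(
        variables['params'], variables['perturbations'], x, y)
    # param_grads mirrors the structure of variables['params'], while intm_grads
    # has one entry per perturb() call: 'before_multiply' and 'after_multiply'.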
