Skip to content

Commit

Permalink
Update documentation, add test (#27)
Browse files Browse the repository at this point in the history
* Update documentation, add test
* Remove backticks in lieu of quotes
  • Loading branch information
dirmeier committed Feb 1, 2024
1 parent 4b052c6 commit 6ca6602
Show file tree
Hide file tree
Showing 5 changed files with 92 additions and 17 deletions.
6 changes: 0 additions & 6 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -59,12 +59,6 @@ repos:
args: ["--ignore-missing-imports"]
files: "(surjectors|examples)"

- repo: https://github.com/jorisroovers/gitlint
rev: v0.19.1
hooks:
- id: gitlint
- id: gitlint-ci

- repo: https://github.com/pycqa/pydocstyle
rev: 6.1.1
hooks:
Expand Down
19 changes: 15 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ You can, for instance, construct a simple normalizing flow like this:

```python
import distrax
from jax import numpy as jnp
import haiku as hk
from jax import numpy as jnp, random as jr
from surjectors import Slice, LULinear, Chain
from surjectors import TransformedDistribution
from surjectors.nn import make_mlp
Expand All @@ -37,9 +38,19 @@ def decoder_fn(n_dim):
return distrax.Independent(distrax.Normal(means, jnp.exp(log_scales)))
return _fn

base_distribution = distrax.Normal(jnp.zeros(5), jnp.ones(5))
transform = Chain([Slice(10, decoder_fn(10)), LULinear(5)])
pushforward = TransformedDistribution(base_distribution, transform)
@hk.without_apply_rng
@hk.transform
def flow(x):
base_distribution = distrax.Independent(
distrax.Normal(jnp.zeros(5), jnp.ones(5)), 1
)
transform = Chain([Slice(5, decoder_fn(5)), LULinear(5)])
pushforward = TransformedDistribution(base_distribution, transform)
return pushforward.log_prob(x)

x = jr.normal(jr.PRNGKey(1), (1, 10))
params = flow.init(jr.PRNGKey(2), x)
lp = flow.apply(params, x)
```

More self-contained examples can be found in [examples](https://github.com/dirmeier/surjectors/tree/main/examples).
Expand Down
20 changes: 15 additions & 5 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ Example
You can, for instance, construct a simple normalizing flow like this:

>>> import distrax
>>> from jax import numpy as jnp
>>> from surjectors import Slice, LULinear, Chain
>>> import haiku as hk
>>> from jax import numpy as jnp, random as jr
>>> from surjectors import TransformedDistribution
>>> from surjectors.nn import make_mlp
>>>
Expand All @@ -38,9 +38,19 @@ You can, for instance, construct a simple normalizing flow like this:
>>> return distrax.Independent(distrax.Normal(means, jnp.exp(log_scales)))
>>> return _fn
>>>
>>> base_distribution = distrax.Normal(jnp.zeros(5), jnp.ones(1))
>>> transform = Chain([Slice(10, decoder_fn(10)), LULinear(5)])
>>> pushforward = TransformedDistribution(base_distribution, transform)
>>> @hk.without_apply_rng
>>> @hk.transform
>>> def flow(x):
>>> base_distribution = distrax.Independent(
>>> distrax.Normal(jnp.zeros(5), jnp.ones(5)), 1
>>> )
>>> transform = Chain([Slice(5, decoder_fn(5)), LULinear(5)])
>>> pushforward = TransformedDistribution(base_distribution, transform)
>>> return pushforward.log_prob(x)
>>>
>>> x = jr.normal(jr.PRNGKey(1), (1, 10))
>>> params = flow.init(jr.PRNGKey(2), x)
>>> lp = flow.apply(params, x)

The flow is constructed using three objects: a base distribution, a transformation, and a transformed distribution.

Expand Down
4 changes: 2 additions & 2 deletions surjectors/_src/bijectors/masked_autoregressive.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ def __init__(
"""
if event_ndims is not None and event_ndims < inner_event_ndims:
raise ValueError(
f"`event_ndims={event_ndims}` should be at least as"
f" large as `inner_event_ndims={inner_event_ndims}`."
f"'event_ndims={event_ndims}' should be at least as"
f" large as 'inner_event_ndims={inner_event_ndims}'."
)
if not isinstance(conditioner, MADE):
raise ValueError(
Expand Down
60 changes: 60 additions & 0 deletions surjectors/_src/surjectors/slice_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# pylint: skip-file

import distrax
import haiku as hk
from jax import numpy as jnp
from jax import random

from surjectors import Slice, TransformedDistribution
from surjectors._src.conditioners.mlp import make_mlp


def _decoder_fn(n_dim):
def _fn(z):
params = make_mlp([32, 32, n_dim * 2])(z)
means, log_scales = jnp.split(params, 2, -1)
return distrax.Independent(distrax.Normal(means, jnp.exp(log_scales)))

return _fn


def _base_distribution_fn(n_latent):
base_distribution = distrax.Independent(
distrax.Normal(jnp.zeros(n_latent), jnp.ones(n_latent)),
reinterpreted_batch_ndims=1,
)
return base_distribution


def make_surjector(n_dimension, n_latent):
def _transformation_fn():
slice = Slice(n_latent, _decoder_fn(n_dimension - n_latent))
return slice

def _flow(method, **kwargs):
td = TransformedDistribution(
_base_distribution_fn(n_latent), _transformation_fn()
)
return td(method, **kwargs)

td = hk.transform(_flow)
return td


def test_slice():
n_dimension, n_latent = 10, 3
y = random.normal(random.PRNGKey(1), shape=(10, n_dimension))

flow = make_surjector(n_dimension, n_latent)
params = flow.init(random.PRNGKey(0), method="log_prob", y=y)
_ = flow.apply(params, None, method="log_prob", y=y)


def test_conditional_slice():
n_dimension, n_latent = 10, 3
y = random.normal(random.PRNGKey(1), shape=(10, n_dimension))
x = random.normal(random.PRNGKey(1), shape=(10, 2))

flow = make_surjector(n_dimension, n_latent)
params = flow.init(random.PRNGKey(0), method="log_prob", y=y, x=x)
_ = flow.apply(params, None, method="log_prob", y=y, x=x)

0 comments on commit 6ca6602

Please sign in to comment.