
Gated MLP from "Pay Attention to MLPs" in jax


sooheon/gmlp-jax


gMLP Jax

gMLP models with Spatial Gating Units (SGUs), implemented in JAX with Flax.
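The core of a Spatial Gating Unit can be sketched in plain `jax.numpy`. This is a minimal sketch, assuming inputs shaped `(batch, tokens, channels)` with an even channel count. The function names (`layer_norm`, `spatial_gating_unit`) are hypothetical and are not this repository's API. It splits the channels into `u` and `v`, normalizes `v`, mixes `v` across tokens with an `n_tokens x n_tokens` projection, and uses the result to gate `u` elementwise. Following the paper, the spatial weights start near zero and the bias at one, so the unit initially passes `u` through almost unchanged. The `1e-3` scale is an assumption.

```python
import jax
import jax.numpy as jnp


def layer_norm(x, eps=1e-6):
    # Normalize over the channel (last) axis; learned scale/shift omitted for brevity.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / jnp.sqrt(var + eps)


def spatial_gating_unit(x, w_spatial, b_spatial):
    """Hypothetical helper: gate one half of the channels with a token-mixed other half.

    x:         (batch, n_tokens, channels), channels must be even
    w_spatial: (n_tokens, n_tokens) token-mixing weights
    b_spatial: (n_tokens,) token-mixing bias
    """
    u, v = jnp.split(x, 2, axis=-1)  # split along channels
    v = layer_norm(v)
    # Linear projection along the token axis: v'[b, i, c] = sum_j W[i, j] * v[b, j, c] + bias[i]
    v = jnp.einsum("ij,bjc->bic", w_spatial, v) + b_spatial[None, :, None]
    return u * v  # elementwise gating


n_tokens, channels = 128, 1024
key_x, key_w = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (1, n_tokens, channels))
# Near-zero weights and unit bias (assumed scale), so the gate starts close to 1.
w_spatial = 1e-3 * jax.random.normal(key_w, (n_tokens, n_tokens))
b_spatial = jnp.ones((n_tokens,))
y = spatial_gating_unit(x, w_spatial, b_spatial)
print(y.shape)  # (1, 128, 512): the gated output has half the input channels
```

Because the gate halves the channel dimension, a gMLP block typically projects up to a wider hidden size before the SGU and back down afterwards.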

The Flax linen API maps almost one-to-one onto the pseudocode in the paper.

Usage

import jax
import jax.numpy as jnp
from gmlp import gMLP

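rng = jax.random.PRNGKey(42)
model = gMLP(512, attn_features=64)
# Initialize with a dummy input of shape (batch, tokens, channels);
# init_with_output returns the forward-pass output and the variables dict.
out, variables = model.init_with_output(rng, jnp.zeros((1, 128, 320)))
# Print the shape of every parameter array (jax.tree_map is deprecated in recent JAX).
print(jax.tree_util.tree_map(jnp.shape, variables))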
