AllenHW/JAX-MoE

A reference implementation of an MoE LLM in Jax and Haiku

About

This is a reference implementation of a Mixture-of-Experts (MoE) Transformer model in Jax and Haiku. The original motivation was to learn Jax and to understand the details of MoE. I've documented the implementation so that it can be helpful to others going through a similar process.

Key Features

The techniques used are fairly standard among state-of-the-art models. So far the implementation supports:

  • A parallel implementation of the Byte Pair Encoding (BPE) tokenization algorithm.
  • Grouped Query Attention (GQA) (see the sketch after this list)
    • paper
    • A generalization of Multi-Query Attention. Where MQA shares a single key and value head across all query heads, GQA has multiple key and value heads, each shared by a group of query heads. Attention is computed between each key/value head and the query heads within its group. The outputs for all query heads across all groups are then concatenated and projected back to the model dimension by the output projection matrix, the same way as in multi-head attention.
  • Rotary Position Embedding (RoPE) (sketch below)
  • A KV cache that stores the computed key and value heads for previous tokens during decoding (sketch below)
  • Mixture-of-Experts (MoE) block with expert capacity (sketch below)
    • Routing is token-choice: each token chooses the top-K experts scored by the router.
    • Tokens that fall outside an expert's capacity are passed through via the residual connection.
    • Computation for each expert is parallelized, potentially on separate devices.
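
To make the GQA description above concrete, here is a minimal sketch of grouped-query attention in Jax. This is not the repo's actual code; the shapes, names, and the simple `jnp.repeat` dispatch are illustrative assumptions.

```python
import jax
import jax.numpy as jnp

def gqa(q, k, v):
    # q: [seq, n_q_heads, d_head]; k, v: [seq, n_kv_heads, d_head]
    seq, n_q, d = q.shape
    n_kv = k.shape[1]
    group = n_q // n_kv                # query heads per key/value head
    # Repeat each key/value head so it is shared by its group of query heads.
    k = jnp.repeat(k, group, axis=1)   # [seq, n_q_heads, d_head]
    v = jnp.repeat(v, group, axis=1)
    logits = jnp.einsum('qhd,khd->hqk', q, k) / jnp.sqrt(d)
    # Causal mask: position i only attends to positions <= i.
    mask = jnp.tril(jnp.ones((seq, seq), dtype=bool))
    logits = jnp.where(mask[None, :, :], logits, -1e30)
    out = jnp.einsum('hqk,khd->qhd', jax.nn.softmax(logits, axis=-1), v)
    # Concatenate all query heads; the output projection would follow here.
    return out.reshape(seq, n_q * d)
```

With `n_kv_heads == n_q_heads` this reduces to standard multi-head attention, and with `n_kv_heads == 1` it reduces to multi-query attention.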
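Rotary embeddings rotate each pair of channels in the query and key vectors by a position-dependent angle. A minimal sketch, assuming the "rotate-half" pairing convention (implementations differ in how channels are paired):

```python
import jax.numpy as jnp

def rope(x, base=10000.0):
    # x: [seq, n_heads, d_head]; applied to queries and keys before attention.
    seq, _, d = x.shape
    half = d // 2
    freqs = base ** (-jnp.arange(half) / half)           # [half]
    angles = jnp.arange(seq)[:, None] * freqs[None, :]   # [seq, half]
    cos = jnp.cos(angles)[:, None, :]                    # broadcast over heads
    sin = jnp.sin(angles)[:, None, :]
    x1, x2 = x[..., :half], x[..., half:]
    # Rotate each (x1_i, x2_i) pair by its position-dependent angle.
    return jnp.concatenate([x1 * cos - x2 * sin,
                            x1 * sin + x2 * cos], axis=-1)
```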
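With a KV cache, autoregressive decoding only computes the new token's key and value heads each step; earlier ones are read back from a preallocated buffer. A minimal sketch (the buffer layout and names are assumptions, not the repo's API):

```python
import jax
import jax.numpy as jnp

def update_kv_cache(cache_k, cache_v, new_k, new_v, pos):
    # cache_k, cache_v: [max_seq, n_kv_heads, d_head], preallocated once.
    # new_k, new_v:     [1, n_kv_heads, d_head] for the token at position `pos`.
    cache_k = jax.lax.dynamic_update_slice(cache_k, new_k, (pos, 0, 0))
    cache_v = jax.lax.dynamic_update_slice(cache_v, new_v, (pos, 0, 0))
    # Attention for the new token then reads cache_k[:pos + 1], cache_v[:pos + 1]
    # against its single query, instead of recomputing k/v for the whole prefix.
    return cache_k, cache_v
```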
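Finally, a minimal sketch of token-choice MoE dispatch with a fixed per-expert capacity. The einsum-based dispatch tensor and the sequential loop over experts are illustrative assumptions; as noted above, each expert's computation can instead run in parallel, potentially on separate devices.

```python
import jax
import jax.numpy as jnp

def moe_layer(x, router_w, experts, top_k, capacity):
    # x: [T, D] tokens; router_w: [D, E]; experts: one function per expert.
    T, D = x.shape
    E = router_w.shape[1]
    probs = jax.nn.softmax(x @ router_w, axis=-1)        # [T, E]
    gate, idx = jax.lax.top_k(probs, top_k)              # each [T, K]
    onehot = jax.nn.one_hot(idx, E)                      # [T, K, E]
    # Position of each (token, k) slot in its chosen expert's queue.
    pos = jnp.cumsum(onehot.reshape(T * top_k, E), axis=0) - 1
    pos = (pos.reshape(T, top_k, E) * onehot).sum(-1)    # [T, K]
    # Slots with pos >= capacity get an all-zero one-hot, i.e. are dropped.
    dispatch = onehot[..., None] * jax.nn.one_hot(pos, capacity)[:, :, None, :]
    expert_in = jnp.einsum('tkec,td->ecd', dispatch, x)  # [E, capacity, D]
    expert_out = jnp.stack([f(expert_in[e]) for e, f in enumerate(experts)])
    y = jnp.einsum('tkec,tk,ecd->td', dispatch, gate, expert_out)
    # Dropped tokens contribute nothing to y, so they pass through as residual.
    return x + y
```

The auxiliary load-balancing loss that is common in MoE training is omitted here for brevity.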

TODO

  • Implement training loop
  • Find a dataset to use for training
  • Generation code
  • Large-scale training, potentially on multiple GPUs
  • Implement Flash Attention?
  • Attempt some interpretability techniques

References
