
Simple DDPM

This repository is a JAX implementation of the Denoising Diffusion Probabilistic Models (DDPM) paper. The model is trained on the Anime Face Dataset from Kaggle.

The implementation is meant to be simple to follow. The forward and backward diffusion processes are implemented in src/diffusion.py.
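To give an idea of what the forward process computes, here is a minimal, standalone JAX sketch of the closed-form forward step q(x_t | x_0) with a linear beta schedule. It is illustrative only and not necessarily identical to the code in src/diffusion.py.

import jax
import jax.numpy as jnp

def linear_beta_schedule(timesteps: int) -> jnp.ndarray:
    # Linear beta schedule from the original DDPM paper.
    return jnp.linspace(1e-4, 2e-2, timesteps)

def forward_diffusion(x0, t, key, alphas_bar):
    # Sample x_t ~ q(x_t | x_0) in closed form:
    # x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps, with eps ~ N(0, I).
    eps = jax.random.normal(key, x0.shape)
    a_bar = alphas_bar[t].reshape(-1, *([1] * (x0.ndim - 1)))  # broadcast over image dims
    xt = jnp.sqrt(a_bar) * x0 + jnp.sqrt(1.0 - a_bar) * eps
    return xt, eps  # eps is the regression target for the denoising network

betas = linear_beta_schedule(1000)
alphas_bar = jnp.cumprod(1.0 - betas)

key = jax.random.PRNGKey(0)
x0 = jax.random.uniform(key, (8, 64, 64, 3)) * 2.0 - 1.0  # fake batch of images in [-1, 1]
t = jax.random.randint(key, (8,), 0, 1000)
xt, eps = forward_diffusion(x0, t, key, alphas_bar)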

Figures: backward diffusion process, generation examples.

Model architecture

The architecture is the U-ViT from this paper. I departed from the usual U-Nets because I find the U-ViT conceptually simpler and easier to manage (fewer hyperparameters, fewer choices overall). A rough sketch of the idea is given after the figure below.

Figure: U-ViT architecture.
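To illustrate the overall shape of a U-ViT (patchified image tokens, the timestep appended as an extra token, and long skip connections between shallow and deep transformer blocks), here is a minimal sketch written with flax.linen. It is an assumption-laden illustration, not the repository's actual model: the class names Block and UViT are hypothetical, the repository may use a different JAX framework, and the paper uses a sinusoidal timestep embedding rather than the toy MLP embedding shown here.

import jax.numpy as jnp
import flax.linen as nn

class Block(nn.Module):
    # Pre-norm transformer block: self-attention + MLP, both with residual connections.
    dim: int
    heads: int = 4

    @nn.compact
    def __call__(self, x):
        x = x + nn.SelfAttention(num_heads=self.heads)(nn.LayerNorm()(x))
        h = nn.Dense(4 * self.dim)(nn.LayerNorm()(x))
        h = nn.Dense(self.dim)(nn.gelu(h))
        return x + h

class UViT(nn.Module):
    dim: int = 256
    depth: int = 6   # number of shallow blocks == number of deep blocks
    patch: int = 4

    @nn.compact
    def __call__(self, x, t):
        b, h, w, c = x.shape
        p = self.patch

        # Patchify: (B, H, W, C) -> (B, H/p * W/p, p*p*C), then embed linearly.
        x = x.reshape(b, h // p, p, w // p, p, c).transpose(0, 1, 3, 2, 4, 5)
        x = x.reshape(b, (h // p) * (w // p), p * p * c)
        x = nn.Dense(self.dim)(x)

        # Timestep embedding, appended as one extra token.
        t_emb = nn.Dense(self.dim)(nn.gelu(nn.Dense(self.dim)(t[:, None].astype(jnp.float32))))
        x = jnp.concatenate([t_emb[:, None, :], x], axis=1)

        # Shallow blocks, one middle block, then deep blocks fed with long skip connections.
        skips = []
        for _ in range(self.depth):
            x = Block(self.dim)(x)
            skips.append(x)
        x = Block(self.dim)(x)
        for _ in range(self.depth):
            x = nn.Dense(self.dim)(jnp.concatenate([x, skips.pop()], axis=-1))
            x = Block(self.dim)(x)

        # Drop the time token and unpatchify back to an image of noise predictions.
        x = nn.Dense(p * p * c)(x[:, 1:, :])
        x = x.reshape(b, h // p, w // p, p, p, c).transpose(0, 1, 3, 2, 4, 5)
        return x.reshape(b, h, w, c)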

Installation

Runtime dependencies

You need Python 3.12 and PDM. Run pdm sync to install all Python dependencies.

Download training data

You can download the dataset using the Kaggle CLI and prepare it as follows:

mkdir -p datasets
kaggle datasets download -d splcher/animefacedataset
unzip animefacedataset.zip
mv images datasets/anime-faces
rm animefacedataset.zip

Training

Launch training

This repository uses Hydra to manage configurations and Weights & Biases (wandb) to log training metrics. The default hyperparameters are in the default.yaml files under the config/ directory.

python3 main.py mode=[online|offline] dataset=[default|small] model=[default|small]
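For reference, a hypothetical sketch of how such an entry point could wire Hydra and wandb together is shown below. It only illustrates the pattern: the train function and the project name are made up, and the repository's actual main.py may be organized differently.

import hydra
import wandb
from omegaconf import DictConfig, OmegaConf

@hydra.main(version_base=None, config_path="config", config_name="default")
def main(cfg: DictConfig) -> None:
    # cfg.mode selects the wandb mode: "online" logs to the server, "offline" keeps runs local.
    with wandb.init(project="anime-diffusion", config=OmegaConf.to_container(cfg), mode=cfg.mode):
        train(cfg)  # hypothetical training entry point, not defined in this sketch

if __name__ == "__main__":
    main()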

Training details

The main model has 15M parameters and was trained for 30 hours on a laptop RTX 3080. I suspect you could get good results with a smaller model as well.

What next?

One could implement the score-matching or the differential-equation (SDE) formulation of diffusion. For DDPM, a quick improvement would be to fix the noise schedule following the ideas in this paper; a sketch of a cosine schedule is given below.
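Assuming the schedule fix referred to is the cosine schedule from "Improved Denoising Diffusion Probabilistic Models" (Nichol & Dhariwal, 2021), which the text alone does not confirm, a minimal JAX sketch would look like this:

import jax.numpy as jnp

def cosine_beta_schedule(timesteps: int, s: float = 0.008) -> jnp.ndarray:
    # Betas derived from a cosine alpha_bar curve, clipped to 0.999 as in that paper.
    steps = jnp.arange(timesteps + 1)
    alphas_bar = jnp.cos(((steps / timesteps) + s) / (1.0 + s) * jnp.pi / 2.0) ** 2
    alphas_bar = alphas_bar / alphas_bar[0]
    betas = 1.0 - alphas_bar[1:] / alphas_bar[:-1]
    return jnp.clip(betas, 0.0, 0.999)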

Sources

Huge shout-out to Stanley H. Chan for his tutorial on diffusion models. It is what motivated me to build a minimal implementation and helped me a lot in understanding the equations.
