Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Beta-VAE #728

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions beta-vae/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Basic beta-VAE Example

This is an implementation of the paper [beta-VAE: Learning Basic Visual Concepts with a Constrained Variational Framework](https://openreview.net/pdf?id=Sy2fzU9gl).

We did experimentation on the FashionMNIST dataset, using a very simple convolution neural network architecture.

```bash
pip install -r requirements.txt
python main.py
```
The main.py script accepts the following arguments:

```bash
optional arguments:
--batch-size input batch size for training (default: 128)
--epochs number of epochs to train (default: 10)
--no-cuda enables CUDA training
--seed random seed (default: 1)
--log-interval how many batches to wait before logging training status
```
203 changes: 203 additions & 0 deletions beta-vae/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@
from __future__ import print_function
import argparse
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image


from typing import List, Callable, Union, Any, TypeVar, Tuple
Tensor = TypeVar('torch.tensor')



parser = argparse.ArgumentParser(description='VAE MNIST Example')
parser.add_argument('--batch-size', type=int, default=128, metavar='N',
help='input batch size for training (default: 128)')
parser.add_argument('--epochs', type=int, default=10, metavar='N',
help='number of epochs to train (default: 10)')
parser.add_argument('--no-cuda', action='store_true', default=False,
help='enables CUDA training')
parser.add_argument('--seed', type=int, default=1, metavar='S',
help='random seed (default: 1)')
parser.add_argument('--log-interval', type=int, default=10, metavar='N',
help='how many batches to wait before logging training status')
args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()

torch.manual_seed(args.seed)

device = torch.device("cuda" if args.cuda else "cpu")

kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}

train_loader = torch.utils.data.DataLoader(
datasets.FashionMNIST('../data', train=True, download=True,
transform=transforms.ToTensor()),
batch_size=args.batch_size, shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(
datasets.FashionMNIST('../data', train=False, transform=transforms.ToTensor()),
batch_size=args.batch_size, shuffle=True, **kwargs)


class BetaVAE(nn.Module):
def __init__(self,
in_channels: int,
latent_dim: int,
hidden_dims: List = None,
beta: int = 4):
super(BetaVAE, self).__init__()
self.latent_dim = latent_dim
self.beta = beta

modules = []
if hidden_dims is None:
hidden_dims = [16, 4]

# Build Encoder
for h_dim in hidden_dims:
modules.append(
nn.Sequential(
nn.Conv2d(in_channels, out_channels=h_dim,
kernel_size= 3, padding = 1),
nn.ReLU(),
nn.MaxPool2d(2, 2))
)
in_channels = h_dim

self.encoder = nn.Sequential(*modules)
self.fc_mu = nn.Linear(hidden_dims[-1]*7*7, latent_dim)
self.fc_var = nn.Linear(hidden_dims[-1]*7*7, latent_dim)

# Build Decoder
modules = []

self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1]*7*7)

hidden_dims.reverse()

for i in range(len(hidden_dims) - 1):
modules.append(
nn.Sequential(
nn.ConvTranspose2d(hidden_dims[i],
hidden_dims[i + 1],
kernel_size=2,
stride = 2),
nn.ReLU())
)

self.decoder = nn.Sequential(*modules)

self.final_layer = nn.Sequential(
nn.ConvTranspose2d(hidden_dims[-1],
1,
kernel_size=2,
stride=2),
nn.Sigmoid())


def encode(self, input: Tensor) -> List[Tensor]:
"""
Encodes the input by passing through the encoder network
and returns the latent codes.
:param input: (Tensor) Input tensor to encoder [N x C x H x W]
:return: (Tensor) List of latent codes
"""
result = self.encoder(input)
result = torch.flatten(result, start_dim=1)

# Split the result into mu and var components
# of the latent Gaussian distribution
mu = self.fc_mu(result)
log_var = self.fc_var(result)

return [mu, log_var]

def decode(self, z: Tensor) -> Tensor:
result = self.decoder_input(z)
result = result.view(-1, 4, 7, 7)
result = self.decoder(result)
result = self.final_layer(result)
return result

def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
"""
Will a single z be enough ti compute the expectation
for the loss??
:param mu: (Tensor) Mean of the latent Gaussian
:param logvar: (Tensor) Standard deviation of the latent Gaussian
:return:
"""
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)
return eps * std + mu

def forward(self, input: Tensor, **kwargs) -> Tensor:
mu, log_var = self.encode(input)
z = self.reparameterize(mu, log_var)
return self.decode(z), mu, log_var

def loss_function(self,
*args,
**kwargs) -> dict:
recons = args[0]
input = args[1]
mu = args[2]
log_var = args[3]
criterion = nn.BCELoss(reduction='sum')
recons_loss =criterion(recons, input)

kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0)

return recons_loss + self.beta * kld_loss

model = BetaVAE(1, 100, beta=5).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)


def train(epoch):
model.train()
train_loss = 0
for batch_idx, (data, _) in enumerate(train_loader):
data = data.to(device)
optimizer.zero_grad()
recon_batch, mu, logvar = model(data)
loss = model.loss_function(recon_batch, data, mu, logvar)
loss.backward()
train_loss += loss.item()
optimizer.step()
if batch_idx % args.log_interval == 0:
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
epoch, batch_idx * len(data), len(train_loader.dataset),
100. * batch_idx / len(train_loader),
loss.item() / len(data)))

print('====> Epoch: {} Average loss: {:.4f}'.format(
epoch, train_loss / len(train_loader.dataset)))


def test(epoch):
model.eval()
test_loss = 0
with torch.no_grad():
for i, (data, _) in enumerate(test_loader):
data = data.to(device)
recon_batch, mu, logvar = model(data)
test_loss += model.loss_function(recon_batch, data, mu, logvar).item()
if i == 0:
n = min(data.size(0), 8)
# Allow to save the orginal image and the generated one as we go
comparison = torch.cat([data[:n],
recon_batch.view(args.batch_size, 1, 28, 28)[:n]])
save_image(comparison.cpu(),
'results/reconstruction_' + str(epoch) + '.png', nrow=n)

test_loss /= len(test_loader.dataset)
print('====> Test set loss: {:.4f}'.format(test_loss))

if __name__ == "__main__":
for epoch in range(1, args.epochs + 1):
train(epoch)
test(epoch)
4 changes: 4 additions & 0 deletions beta-vae/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
torch
torchvision
tqdm
six
Binary file added beta-vae/results/reconstruction_1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added beta-vae/results/reconstruction_10.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added beta-vae/results/reconstruction_2.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added beta-vae/results/reconstruction_3.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added beta-vae/results/reconstruction_4.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added beta-vae/results/reconstruction_5.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added beta-vae/results/reconstruction_6.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added beta-vae/results/reconstruction_7.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added beta-vae/results/reconstruction_8.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added beta-vae/results/reconstruction_9.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
3 changes: 3 additions & 0 deletions vae/.vscode/settings.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
{
"python.pythonPath": "/home/aims/anaconda3/envs/aims/bin/python"
}