Commit

move all exponential moving average logic to separate class, add an EMA for the vqgan-vae during training
lucidrains committed Apr 7, 2022
1 parent 907467f commit d28d902
Showing 2 changed files with 96 additions and 20 deletions.
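The commit factors the exponential-moving-average bookkeeping into a standalone EMA wrapper that keeps a deep copy of the VQGAN-VAE and, every few optimizer steps, blends the copy's weights toward the online weights (old * beta + (1 - beta) * new). A rough usage sketch, under the assumption that the EMA class below is importable from the changed module; the VQGanVAE constructor arguments, the return_loss flag, and the synthetic batch are illustrative placeholders, not code taken from the repository:

import torch
from nuwa_pytorch import VQGanVAE
from nuwa_pytorch.train_vqgan_vae import EMA   # EMA is defined in the file changed by this commit

vae = VQGanVAE(dim = 256, image_size = 256)    # online model being optimized (assumed constructor args)

# hypothetical hyperparameters mirroring the trainer defaults added in this commit
ema_vae = EMA(vae, beta = 0.995, ema_update_after_step = 2000, ema_update_every = 10)

optim = torch.optim.Adam(vae.parameters(), lr = 3e-4)

for step in range(10000):
    images = torch.rand(4, 3, 256, 256)        # stand-in for a real dataloader batch
    loss = vae(images, return_loss = True)     # return_loss flag is an assumption about the VAE forward
    loss.backward()
    optim.step()
    optim.zero_grad()

    ema_vae.update()                           # does nothing until step 2000, then updates the shadow copy every 10 steps

# evaluate / sample with the smoothed weights; EMA.__call__ forwards to its internal ema_model
with torch.no_grad():
    recons = ema_vae(images)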
114 changes: 95 additions & 19 deletions nuwa_pytorch/train_vqgan_vae.py
@@ -1,4 +1,5 @@
from math import sqrt
import copy
from random import choice
from pathlib import Path
from shutil import rmtree
@@ -112,6 +113,58 @@ def __getitem__(self, index):
        img = Image.open(path)
        return self.transform(img)

# exponential moving average wrapper

class EMA(nn.Module):
    def __init__(
        self,
        model,
        beta = 0.99,
        ema_update_after_step = 1000,
        ema_update_every = 10,
    ):
        super().__init__()
        self.beta = beta
        self.online_model = model
        self.ema_model = copy.deepcopy(model)

        self.ema_update_after_step = ema_update_after_step # only start EMA after this step number, starting at 0
        self.ema_update_every = ema_update_every

        self.register_buffer('initted', torch.Tensor([False]))
        self.register_buffer('step', torch.tensor([0.]))

    def update(self):
        self.step += 1

        if self.step <= self.ema_update_after_step or (self.step % self.ema_update_every) != 0:
            return

        if not self.initted:
            # on the first qualifying update, start the EMA copy from the online weights
            self.ema_model.load_state_dict(self.online_model.state_dict())
            self.initted.data.copy_(torch.Tensor([True]))

        self.update_moving_average(self.ema_model, self.online_model)

    def update_moving_average(self, ma_model, current_model):
        def calculate_ema(beta, old, new):
            if not exists(old):
                return new
            return old * beta + (1 - beta) * new

        for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
            old_weight, up_weight = ma_params.data, current_params.data
            ma_params.data = calculate_ema(self.beta, old_weight, up_weight)

        for current_buffer, ma_buffer in zip(current_model.buffers(), ma_model.buffers()):
            new_buffer_value = calculate_ema(self.beta, ma_buffer, current_buffer)
            ma_buffer.copy_(new_buffer_value)

    def __call__(self, *args, **kwargs):
        return self.ema_model(*args, **kwargs)

# main trainer class

class VQGanVAETrainer(nn.Module):
    def __init__(
        self,
@@ -129,14 +182,20 @@ def __init__(
        save_model_every = 1000,
        results_folder = './results',
        valid_frac = 0.05,
        random_split_seed = 42,
        ema_beta = 0.995,
        ema_update_after_step = 2000,
        ema_update_every = 10
    ):
        super().__init__()
        assert isinstance(vae, VQGanVAE), 'vae must be instance of VQGanVAE'
        image_size = vae.image_size

        self.vae = vae
        self.ema_vae = EMA(vae, beta = ema_beta, ema_update_after_step = ema_update_after_step, ema_update_every = ema_update_every)

        self.register_buffer('steps', torch.Tensor([0]))

        self.num_train_steps = num_train_steps
        self.batch_size = batch_size
        self.grad_accum_every = grad_accum_every
@@ -197,6 +256,8 @@ def __init__(

    def train_step(self):
        device = next(self.vae.parameters()).device
        steps = int(self.steps.item())

        self.vae.train()

        # logs
@@ -217,6 +278,7 @@
        self.optim.step()
        self.optim.zero_grad()

        # update discriminator

        if exists(self.vae.discr):
@@ -235,34 +297,48 @@

        # log

        print(f"{steps}: vae loss: {logs['loss']} - discr loss: {logs['discr_loss']}")

        # update exponential moving averaged generator

        self.ema_vae.update()

        # sample results every so often

        if not (steps % self.save_results_every):
            for model, filename in ((self.ema_vae.ema_model, f'{steps}.ema'), (self.vae, str(steps))):
                model.eval()

                imgs = next(self.dl)
                imgs = imgs.to(device)

                recons = model(imgs)
                nrows = int(sqrt(self.batch_size))

                imgs_and_recons = torch.stack((imgs, recons), dim = 0)
                imgs_and_recons = rearrange(imgs_and_recons, 'r b ... -> (b r) ...')

                imgs_and_recons = imgs_and_recons.detach().cpu().float().clamp(0., 1.)
                grid = make_grid(imgs_and_recons, nrow = 2, normalize = True, value_range = (0, 1))

                logs['reconstructions'] = grid

                save_image(grid, str(self.results_folder / f'{filename}.png'))

            print(f'{steps}: saving to {str(self.results_folder)}')

        # save model every so often

        if not (steps % self.save_model_every):
            state_dict = self.vae.state_dict()
            model_path = str(self.results_folder / f'vae.{steps}.pt')
            torch.save(state_dict, model_path)

            ema_state_dict = self.ema_vae.state_dict()
            model_path = str(self.results_folder / f'vae.{steps}.ema.pt')
            torch.save(ema_state_dict, model_path)

            print(f'{steps}: saving model to {str(self.results_folder)}')

        self.steps += 1
        return logs
2 changes: 1 addition & 1 deletion setup.py
@@ -4,7 +4,7 @@
name = 'nuwa-pytorch',
packages = find_packages(exclude=[]),
include_package_data = True,
version = '0.7.1',
version = '0.7.2',
license='MIT',
description = 'NÜWA - Pytorch',
author = 'Phil Wang',
