masking compatible with fullgraph compile #91

Open · wants to merge 3 commits into base: master
examples/autoencoder_lfq_masked.py (135 additions, 0 deletions)
@@ -0,0 +1,135 @@
# FashionMnist VQ experiment with various settings.
# From https://github.com/minyoungg/vqtorch/blob/main/examples/autoencoder.py

from tqdm.auto import trange
from math import log2

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from einops import rearrange

from vector_quantize_pytorch import LFQ

lr = 5e-4
batch_size = 256
train_iter = 1000
seed = 1234
codebook_size = 2 ** 8
# 32 codes per image
num_codebooks = 32
entropy_loss_weight = 0.01
commitment_loss_weight = 0.25
diversity_gamma = 1.
device = "cuda" if torch.cuda.is_available() else "cpu"

class LFQAutoEncoder(nn.Module):
    def __init__(
        self,
        input_dim,
        hidden_dim,
        **vq_kwargs
    ):
        super().__init__()

        self.encode = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
        )

        self.lfq = LFQ(**vq_kwargs)

        self.decode = nn.Sequential(
            nn.Linear(hidden_dim, input_dim),
        )
        return

    def forward(self, x, mask=None):
        x = self.encode(x)
        x, indices, entropy_aux_loss = self.lfq(x, mask=mask)
        x = self.decode(x)
        return x, indices, entropy_aux_loss


def train(model, train_loader, train_iterations=1000, add_masked_data=False):
    def iterate_dataset(data_loader):
        data_iter = iter(data_loader)
        while True:
            try:
                x, y = next(data_iter)
            except StopIteration:
                data_iter = iter(data_loader)
                x, y = next(data_iter)
            yield x.to(device), y.to(device)

    for _ in (pbar := trange(train_iterations)):
        opt.zero_grad()
        x, _ = next(iterate_dataset(train_loader))

        og_shape = x.shape
        x = rearrange(x, 'b c h w -> b 1 (c h w)')

        mask = torch.ones(x.shape[0], 2 if add_masked_data else 1, dtype=torch.bool, device=x.device)
        if add_masked_data:
            masked_data = torch.randn_like(x)
            x = torch.concat([x, masked_data], dim=1)
            # mark the random data as padding (mask is False where padding is)
            mask[:, 1] = False

        out, indices, entropy_aux_loss = model(x, mask=mask)

        rec_loss = F.l1_loss(out[mask], x[mask])
        (rec_loss + entropy_aux_loss).backward()

        opt.step()
        pbar.set_description(
            f"rec loss: {rec_loss.item():.3f} | "
            + f"entropy aux loss: {entropy_aux_loss.item():.3f} | "
            + f"active %: {indices[mask].unique().numel() / codebook_size * 100:.3f}"
        )
    return

transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)

train_dataset = DataLoader(
    datasets.FashionMNIST(
        root="~/data/fashion_mnist", train=True, download=True, transform=transform
    ),
    batch_size=batch_size,
    shuffle=True,
)


torch.random.manual_seed(seed)

mnist_h, mnist_w = 28, 28
mnist_c = 1
input_dim = mnist_h * mnist_w * mnist_c
# this is also the number of codes
hidden_dim = codebook_size

def get_model_and_opt():
    model = LFQAutoEncoder(
        input_dim,
        hidden_dim,
        dim=hidden_dim,
        codebook_size=codebook_size,
        entropy_loss_weight=entropy_loss_weight,
        diversity_gamma=diversity_gamma,
        commitment_loss_weight=commitment_loss_weight,
        num_codebooks=num_codebooks,
    ).to(device)

    opt = torch.optim.AdamW(model.parameters(), lr=lr)
    return model, opt

print("baseline")
model, opt = get_model_and_opt()
train(model, train_dataset, train_iterations=train_iter)

print("with masking")
model, opt = get_model_and_opt()
train(model, train_dataset, train_iterations=train_iter, add_masked_data=True)
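
For context, not part of this PR: the example above exercises the mask argument but never calls torch.compile. Below is a minimal sketch of how the masked path could be checked under full-graph compilation; the shapes, hyperparameters, and the compile call are illustrative assumptions rather than code from this diff.

# Sketch only: run LFQ's masked forward under torch.compile(fullgraph=True),
# the mode that errors on graph breaks instead of falling back to eager.
# All settings here are assumptions chosen to mirror the example above.
import torch
from vector_quantize_pytorch import LFQ

quantizer = LFQ(dim=256, codebook_size=256, num_codebooks=32)
# training mode, so the entropy / commitment aux losses (the masked code
# paths touched by this PR) are actually computed
quantizer.train()

compiled = torch.compile(quantizer, fullgraph=True)

x = torch.randn(2, 4, 256)                         # (batch, seq, dim)
mask = torch.tensor([[True, True, False, False],
                     [True, True, True,  False]])  # False marks padding

quantized, indices, entropy_aux_loss = compiled(x, mask=mask)
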
vector_quantize_pytorch/lookup_free_quantization.py (44 additions, 9 deletions)
@@ -39,6 +39,38 @@ def pack_one(t, pattern):
 def unpack_one(t, ps, pattern):
     return unpack(t, ps, pattern)[0]
 
+
+# masked mean
+
+def mult_along_first_dims(x, y):
+    # returns x * y elementwise along the first dims of x and y
+    ndim_to_expand = x.ndim - y.ndim
+    assert ndim_to_expand >= 0
+    for _ in range(ndim_to_expand):
+        y = y.unsqueeze(-1)
+    return x * y
+
+def masked_mean(x, m):
+    """
+    Takes the mean of the elements of x that are not masked across the first
+    shared dims of x and m.
+
+    Equivalent to: x[m].mean(dim=0)
+
+    m is False where padding is
+    """
+
+    # masks x
+    x = mult_along_first_dims(x, m)
+
+    # divides by the number of non-masked items
+    x = x / m.sum()
+
+    # sum across the leading dims that x and m share
+    return x.sum(dim=list(range(m.ndim)))
+
+
+
 # entropy
 
 def log(t, eps = 1e-5):
@@ -215,16 +247,19 @@ def forward(

             prob = (-distance * inv_temperature).softmax(dim = -1)
 
-            per_sample_entropy = entropy(prob).mean()
-
-            if exists(mask):
-                prob = prob[mask]
-
-            # distribution over all available tokens in the batch
-
-            avg_prob = reduce(prob, '... c d -> c d', 'mean')
+            if exists(mask):
+                # b n c d -> 1
+                per_sample_entropy = masked_mean(entropy(prob), mask).mean()
+
+                # account for mask
+                # b n c d -> c d
+                avg_prob = masked_mean(prob, mask)
+            else:
+                per_sample_entropy = entropy(prob).mean()
+
+                # distribution over all available tokens in the batch
+                avg_prob = reduce(prob, '... c d -> c d', 'mean')
 
             codebook_entropy = entropy(avg_prob).mean()

# 1. entropy will be nudged to be low for each code, to encourage the network to output confident predictions
@@ -241,9 +276,9 @@ def forward(
             commit_loss = F.mse_loss(original_input, quantized.detach(), reduction = 'none')
 
             if exists(mask):
-                commit_loss = commit_loss[mask]
-
-            commit_loss = commit_loss.mean()
+                commit_loss = masked_mean(commit_loss, mask).mean()
+            else:
+                commit_loss = commit_loss.mean()
         else:
             commit_loss = self.zero

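
For reference, not part of the diff: the docstring equivalence above can be checked directly, and it is the reason this change plays well with full-graph compilation. masked_mean keeps every intermediate shape independent of how many positions are masked, whereas boolean indexing such as prob[mask] or commit_loss[mask] produces data-dependent shapes that torch.compile generally cannot keep in a single graph. A small self-contained sketch follows, with the helpers re-declared under the assumption that they mirror the ones added above.

# Sketch only: check that the shape-preserving masked mean matches boolean
# indexing. Helpers are re-declared here so the snippet runs on its own.
import torch

def mult_along_first_dims(x, y):
    # broadcast y over the trailing dims of x
    for _ in range(x.ndim - y.ndim):
        y = y.unsqueeze(-1)
    return x * y

def masked_mean(x, m):
    x = mult_along_first_dims(x, m)
    x = x / m.sum()
    return x.sum(dim=list(range(m.ndim)))

prob = torch.rand(2, 3, 4, 5)                  # b n c d
mask = torch.tensor([[True, True, False],
                     [True, False, False]])    # b n, False = padding

# same (c, d)-shaped result, but masked_mean never materialises a tensor
# whose shape depends on the number of unmasked positions
assert torch.allclose(masked_mean(prob, mask), prob[mask].mean(dim=0))

The trade-off is that masked positions are still multiplied through rather than skipped, which costs a little extra compute but keeps the traced graph static.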