Commit

[FUNCTIONAL]
Kye committed Dec 29, 2023
1 parent a737835 commit 9259885
Showing 2 changed files with 199 additions and 38 deletions.
178 changes: 171 additions & 7 deletions qformer/main.py
@@ -1,16 +1,180 @@
import math

import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from torch import Tensor, einsum, nn
from zeta import LayerNorm, default, exists, l2norm
from zeta.nn import (
    MultiQueryAttention,
    SimpleFeedForward,
)
from zeta.utils import enforce_types


class CrossAttention(nn.Module):
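    """
    Multi-head cross-attention from a query sequence to a context sequence,
    with a learned null key/value pair and optional cosine-similarity attention.
    """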
def __init__(
self,
dim,
*,
context_dim=None,
dim_head=64,
heads=8,
dropout=0.0,
norm_context=False,
cosine_sim=False,
cosine_sim_scale=16,
):
super().__init__()
self.cosine_sim = cosine_sim
self.scale = (
cosine_sim_scale if cosine_sim else (dim_head**-0.5)
)
self.heads = heads
inner_dim = dim_head * heads

context_dim = default(context_dim, dim)

self.norm = LayerNorm(dim)
self.norm_context = (
LayerNorm(context_dim) if norm_context else nn.Identity()
)
self.dropout = nn.Dropout(dropout)

self.null_kv = nn.Parameter(torch.randn(2, dim_head))
self.to_q = nn.Linear(dim, inner_dim, bias=False)
self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=False)

self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim, bias=False), LayerNorm(dim)
)

def forward(self, x, context, mask=None):
b, n, device = *x.shape[:2], x.device

x = self.norm(x)
context = self.norm_context(context)

q, k, v = (
self.to_q(x),
*self.to_kv(context).chunk(2, dim=-1),
)

q, k, v = map(
lambda t: rearrange(
t, "b n (h d) -> b h n d", h=self.heads
),
(q, k, v),
)

# add null key / value for classifier free guidance in prior net

nk, nv = map(
lambda t: repeat(t, "d -> b h 1 d", h=self.heads, b=b),
self.null_kv.unbind(dim=-2),
)

k = torch.cat((nk, k), dim=-2)
v = torch.cat((nv, v), dim=-2)

if self.cosine_sim:
q, k = map(l2norm, (q, k))

q, k = map(lambda t: t * math.sqrt(self.scale), (q, k))

sim = einsum("b h i d, b h j d -> b h i j", q, k)
max_neg_value = -torch.finfo(sim.dtype).max

if exists(mask):
mask = F.pad(mask, (1, 0), value=True)
mask = rearrange(mask, "b j -> b 1 1 j")
sim = sim.masked_fill(~mask, max_neg_value)

attn = sim.softmax(dim=-1, dtype=torch.float32)
attn = attn.type(sim.dtype)

out = einsum("b h i j, b h j d -> b h i d", attn, v)
out = rearrange(out, "b h n d -> b n (h d)")
return self.to_out(out)


class ImgBlock(nn.Module):
"""
ImgBlock is a module that performs multi-query attention, cross-attention, and feedforward operations on input tensors.
Args:
dim (int): The dimension of the input tensors.
depth (int): The number of times the operations are applied.
heads (int): The number of attention heads.
dropout (float, optional): The dropout probability. Defaults to 0.1.
Attributes:
dim (int): The dimension of the input tensors.
depth (int): The number of times the operations are applied.
heads (int): The number of attention heads.
dropout (float): The dropout probability.
attn (MultiQueryAttention): The multi-query attention module.
cross_attn (CrossAttention): The cross-attention module.
feedforward (SimpleFeedForward): The feedforward module.
Methods:
forward(x: Tensor, img: Tensor) -> Tensor:
Performs the forward pass of the ImgBlock module.
"""

@enforce_types
def __init__(
        self,
dim: int,
depth: int,
heads: int,
dropout: float = 0.1,
*args,
**kwargs,
):
super(ImgBlock, self).__init__(*args, **kwargs)
self.dim = dim
self.depth = depth
self.heads = heads
self.dropout = dropout

self.attn = MultiQueryAttention(dim, heads)

self.cross_attn = CrossAttention(
dim=dim,
heads=heads,
dropout=dropout,
)

self.feedforward = SimpleFeedForward(dim, dim * 4, dropout)

@enforce_types
    def forward(self, x: Tensor, img: Tensor) -> Tensor:
"""
Performs the forward pass of the ImgBlock module.
Args:
x (Tensor): The input tensor.
img (Tensor): The image tensor.
Returns:
Tensor: The output tensor after applying multi-query attention, cross-attention, and feedforward operations.
"""
        for _ in range(self.depth):
            attended, _, _ = self.attn(x)
            crossed = self.cross_attn(attended, img)
            x = self.feedforward(crossed)
        return x


if __name__ == "__main__":
    # 3d tensor, B x SEQLEN x DIM
    x = torch.randn(1, 32, 1024)
    image = torch.randn(1, 32, 1024)

    attn = ImgBlock(1024, 8, 1024)
    out = attn(x, image)
    print(out.shape)
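For reference, a minimal sketch of driving the CrossAttention block on its own with a padding mask; the import path, tensor shapes, and boolean-mask convention below are assumptions for illustration, not part of this commit.

import torch

from qformer.main import CrossAttention  # assumed module path

# Hypothetical standalone usage: 16 query tokens attending over 49 context tokens.
queries = torch.randn(2, 16, 512)        # (batch, query_len, dim)
context = torch.randn(2, 49, 512)        # (batch, context_len, dim), e.g. image patch features
context_mask = torch.ones(2, 49).bool()  # True where a context token is valid

cross_attn = CrossAttention(dim=512, heads=8, dim_head=64)
out = cross_attn(queries, context, mask=context_mask)
print(out.shape)  # torch.Size([2, 16, 512])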
59 changes: 28 additions & 31 deletions qformer/masking.py
@@ -1,46 +1,43 @@
import torch


def multi_modal_causal_self_attention_mask(x):
    """
    Applies a multi-modal causal self-attention mask. The mask allows query tokens to attend to all
    other query tokens, and text tokens to attend only to preceding text tokens and all query tokens.

    Args:
        x (torch.Tensor): the input tensor of shape [batch_size, seqlen, dim].

    Returns:
        torch.Tensor: the mask tensor of shape [batch_size, seqlen, seqlen], with 0s where attention is allowed
        and float('-inf') where it is not, suitable for adding to the raw attention scores.
    """
    batch_size, seqlen, _ = x.shape

    # Initialize mask to all ones (attention allowed everywhere by default)
    mask = torch.ones((seqlen, seqlen), dtype=torch.float32)

    # Create a causal mask for the text tokens, assumed to occupy the second half of the sequence
    causal_mask = torch.tril(
        torch.ones((seqlen // 2, seqlen // 2), dtype=torch.float32)
    )
    mask[-(seqlen // 2) :, -(seqlen // 2) :] = causal_mask

    # Invert the mask so that 0s are where attention is allowed and float('-inf') where it is not
    mask = torch.log(mask)

    # Expand the mask for the batch size
    mask = mask.repeat(batch_size, 1, 1)
    return mask
batch_size = 2
seqlen = 8
dim = 512

# Example to test the function with dummy data
x_dummy = torch.rand(batch_size, seqlen, dim) # Dummy data
multi_modal_causal_mask = multi_modal_causal_self_attention_mask(
x_dummy
)
print(multi_modal_causal_mask.shape)
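To make the layout concrete, a small sketch (illustration only, not part of the commit; the module path is assumed) that prints the additive mask for a toy sequence of eight tokens, which the code treats as four query tokens followed by four text tokens.

import torch

from qformer.masking import multi_modal_causal_self_attention_mask  # assumed module path

toy = torch.randn(1, 8, 512)
m = multi_modal_causal_self_attention_mask(toy)[0]
print(m)
# The first 4 rows are all zeros, so those positions attend to every token;
# the last 4 rows are zero over the first 4 columns (text attends to all query tokens)
# and lower-triangular 0 / -inf over the last 4 columns (causal attention among text tokens).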
