
Commit

cleanup conv-like attention
lucidrains committed Mar 24, 2022
1 parent 3064403 commit 2969103
Showing 2 changed files with 8 additions and 8 deletions.
14 changes: 7 additions & 7 deletions dalle_pytorch/attention.py
@@ -177,18 +177,18 @@ def forward(self, x, mask = None, rotary_pos_emb = None):
  dots_image = einsum('b i d, b i j d -> b i j', q_img, k_img)
  dots_image_to_text = einsum('b i d, b j d -> b i j', q_img, k_text)

- # calculate causal attention for local convolution
+ # use padding of 0 on tensor of 1s and unfold for padding mask

  i, j = dots_image.shape[-2:]
- img_seq = torch.arange(img_seq_len, device = device)
- k_img_indices = rearrange(img_seq.float(), '(h w) -> () () h w', h = img_size)
- k_img_indices = F.pad(k_img_indices, causal_padding, value = img_seq_len) # padding set to be max, so it is never attended to
- k_img_indices = F.unfold(k_img_indices, kernel_size, dilation = dilation)
- k_img_indices = rearrange(k_img_indices, 'b j i -> b i j')
+ ones = torch.ones((img_seq_len,), device = device)
+ ones = rearrange(ones, '(h w) -> () () h w', h = img_size)
+ ones = F.pad(ones, causal_padding, value = 0.)
+ ones = F.unfold(ones, kernel_size, dilation = dilation)
+ ones = rearrange(ones, 'b j i -> b i j')

  # mask image attention

- padding_mask = k_img_indices == img_seq_len
+ padding_mask = ones == 0.

  # concat text mask with image causal mask
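The change above replaces the index-based mask (pad with img_seq_len, then compare) with a simpler ones-based mask: pad a tensor of ones with zeros, unfold it into local windows, and mark the zero positions as padding. Below is a minimal standalone sketch of that idea, not the repository's code; the values of img_size, kernel_size, dilation, and causal_padding are assumptions chosen for illustration.

import torch
import torch.nn.functional as F
from einops import rearrange

img_size = 4                      # assumed 4 x 4 image feature map
img_seq_len = img_size ** 2
kernel_size = 3
dilation = 1
# pad only top/left so each query sees a causal (past-only) local window
causal_padding = (kernel_size - 1, 0, kernel_size - 1, 0)

ones = torch.ones((img_seq_len,))
ones = rearrange(ones, '(h w) -> () () h w', h = img_size)       # (1, 1, 4, 4)
ones = F.pad(ones, causal_padding, value = 0.)                   # zeros mark out-of-image positions
ones = F.unfold(ones, kernel_size, dilation = dilation)          # (1, k*k, img_seq_len)
ones = rearrange(ones, 'b j i -> b i j')                         # (1, img_seq_len, k*k)

padding_mask = ones == 0.   # True where a window slot falls in the padding
print(padding_mask.shape)   # torch.Size([1, 16, 9])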
2 changes: 1 addition & 1 deletion setup.py
@@ -4,7 +4,7 @@
  name = 'dalle-pytorch',
  packages = find_packages(),
  include_package_data = True,
- version = '1.5.0',
+ version = '1.5.1',
  license='MIT',
  description = 'DALL-E - Pytorch',
  author = 'Phil Wang',
