Commit 1260bf7
cleanup
lucidrains committed Mar 24, 2022
1 parent 649a916 commit 1260bf7
Showing 2 changed files with 15 additions and 17 deletions.
30 changes: 14 additions & 16 deletions nuwa_pytorch/nuwa_pytorch.py
@@ -446,19 +446,19 @@ def __init__(
 
         # precalculate causal mask
 
-        indices = torch.arange(max_num_tokens)
-        shaped_indices = rearrange(indices, '(f h w) -> 1 1 f h w', f = max_frames, h = fmap_size, w = fmap_size)
-        padded_indices = F.pad(shaped_indices, self.video_padding, value = max_num_tokens) # padding has value of max tokens so to be masked out
-        unfolded_indices = unfoldNd(padded_indices, kernel_size = self.kernel_size, dilation = self.dilation)
-        unfolded_indices = rearrange(unfolded_indices, '1 k n -> n k')
+        ones = torch.ones((max_num_tokens,))
+        ones = rearrange(ones, '(f h w) -> 1 1 f h w', f = max_frames, h = fmap_size, w = fmap_size)
+        ones = F.pad(ones, self.video_padding, value = 0.)
+        ones = unfoldNd(ones, kernel_size = self.kernel_size, dilation = self.dilation)
+        ones = rearrange(ones, '1 k n -> n k')
 
         # mask out padding
 
-        mask = unfolded_indices == max_num_tokens
+        padding_mask = ones == 0.
 
         # bos tokens never get masked out
 
-        mask = F.pad(mask, (1, 0), value = False)
+        mask = F.pad(padding_mask, (1, 0), value = False)
         self.register_buffer('mask', mask)
 
     def forward(self, x, **kwargs):
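The change above drops the arange/index bookkeeping: real tokens are marked with 1s, the causal video padding contributes 0s, and any 0 that survives the unfold identifies a padded key. Below is a standalone sketch of the new buffer computation; the toy sizes and the padding tuple are assumptions for illustration, not values taken from the repository.

import torch
import torch.nn.functional as F
from einops import rearrange
from unfoldNd import unfoldNd

# toy sizes, assumed for illustration only
max_frames, fmap_size = 2, 4
max_num_tokens = max_frames * fmap_size ** 2       # 32 video tokens
kernel_size = (2, 3, 3)                            # hypothetical 3d attention window
dilation = (1, 1, 1)

# F.pad order is (w_left, w_right, h_left, h_right, f_left, f_right):
# symmetric in space, frames padded only on the left to stay causal
video_padding = (1, 1, 1, 1, 1, 0)

# real tokens are 1s; any 0 seen after unfolding came from padding
ones = torch.ones((max_num_tokens,))
ones = rearrange(ones, '(f h w) -> 1 1 f h w', f = max_frames, h = fmap_size, w = fmap_size)
ones = F.pad(ones, video_padding, value = 0.)
ones = unfoldNd(ones, kernel_size = kernel_size, dilation = dilation)
ones = rearrange(ones, '1 k n -> n k')             # (32 queries, 18 keys per window)

padding_mask = ones == 0.                          # True where a key fell on padding
mask = F.pad(padding_mask, (1, 0), value = False)  # bos column is never masked
print(mask.shape)                                  # torch.Size([32, 19])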
@@ -659,18 +659,16 @@ def get_mask(self, t):
             return self.mask
 
         device, seq_len = t.device, t.shape[-3] * self.height
-        q_range = torch.arange(seq_len, device = device, dtype = torch.float32)
-        k_range = torch.arange(seq_len, device = device, dtype = torch.float32)
 
-        q_range = rearrange(q_range, '(n m) -> n m', m = self.height)
-        k_range = rearrange(k_range, '(n m) -> 1 1 n m', m = self.height)
+        ones = torch.ones((seq_len,), device = device)
+        ones = rearrange(ones, '(n m) -> 1 1 n m', m = self.height)
 
-        k_range = F.pad(k_range, self.causal_padding, value = seq_len)
-        k_range = unfoldNd(k_range, kernel_size = self.kernel_size, dilation = self.dilation)
-        k_range = rearrange(k_range, '1 d n -> n d')
+        ones = F.pad(ones, self.causal_padding, value = 0.)
+        ones = unfoldNd(ones, kernel_size = self.kernel_size, dilation = self.dilation)
+        ones = rearrange(ones, '1 d n -> n d')
 
-        mask = rearrange(k_range, 'n j -> n 1 j') == seq_len
-        mask = F.pad(mask, (1, 0), value = False)
+        padding_mask = rearrange(ones, 'n j -> n 1 j') == 0.
+        mask = F.pad(padding_mask, (1, 0), value = False)
 
         self.register_buffer('mask', mask, persistent = False)
         return mask
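get_mask applies the same ones-based trick at runtime, rebuilding the mask from the input's sequence length and caching it with persistent = False. A 1d sketch of the idea follows; Tensor.unfold stands in for unfoldNd here, and seq_len, kernel_size, and the padding are assumed toy values.

import torch
import torch.nn.functional as F

seq_len, kernel_size = 8, 3                        # assumed toy values
causal_padding = (kernel_size - 1, 0)              # left-pad only, so no query sees the future

# real positions are 1s; the left padding contributes 0s
ones = torch.ones((seq_len,))
ones = F.pad(ones, causal_padding, value = 0.)

# each query's window of keys
windows = ones.unfold(0, kernel_size, 1)           # shape (seq_len, kernel_size)

padding_mask = windows == 0.                       # True where the window ran off the left edge
mask = F.pad(padding_mask, (1, 0), value = False)  # keep the bos column visible, as above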
2 changes: 1 addition & 1 deletion setup.py
@@ -4,7 +4,7 @@
   name = 'nuwa-pytorch',
   packages = find_packages(exclude=[]),
   include_package_data = True,
-  version = '0.7.0',
+  version = '0.7.1',
   license='MIT',
   description = 'NÜWA - Pytorch',
   author = 'Phil Wang',
