move to using causal convolutions through asymmetric padding
lucidrains committed Mar 24, 2022
1 parent 82f306c commit 649a916
Showing 2 changed files with 26 additions and 22 deletions.
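
The gist of the change: instead of padding the video convolution symmetrically ("same" padding on both sides of each axis), all of the padding is placed before the current position, so no output can depend on tokens that have not been generated yet. Below is a minimal 1d sketch of the trick in plain PyTorch; the class name and numbers are mine, not code from this repository.

```python
import torch
import torch.nn.functional as F
from torch import nn

class CausalConv1d(nn.Module):
    """Toy causal convolution: pad only on the left, then convolve without padding."""
    def __init__(self, dim_in, dim_out, kernel_size, dilation = 1):
        super().__init__()
        # the whole receptive-field overhang sits before position t,
        # so output t never sees inputs later than t
        self.causal_padding = (dilation * (kernel_size - 1), 0)
        self.conv = nn.Conv1d(dim_in, dim_out, kernel_size, dilation = dilation)

    def forward(self, x):
        x = F.pad(x, self.causal_padding)   # (left, right) padding on the time axis
        return self.conv(x)

x = torch.randn(1, 8, 16)                   # (batch, channels, time)
out = CausalConv1d(8, 8, kernel_size = 3)(x)
assert out.shape == x.shape                 # "same" length, but causal
```

The 3d version in the first hunk below does the same thing per axis: the full padding amount goes on one side (before) and zero on the other, instead of being split in half.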
46 changes: 25 additions & 21 deletions nuwa_pytorch/nuwa_pytorch.py
@@ -426,7 +426,12 @@ def __init__(
         self.padding_height = calc_same_padding(self.kernel_size[1], self.dilation[1])
         self.padding_width = calc_same_padding(self.kernel_size[2], self.dilation[2])
 
-        self.video_padding = (self.padding_width, self.padding_width, self.padding_height, self.padding_height, self.padding_frame, self.padding_frame)
+        # use separate padding for causal convolution vs non-causal
+
+        if causal:
+            self.video_padding = (self.padding_width * 2, 0, self.padding_height * 2, 0, self.padding_frame * 2, 0)
+        else:
+            self.video_padding = (self.padding_width, self.padding_width, self.padding_height, self.padding_height, self.padding_frame, self.padding_frame)
 
         # save video shape and calculate max number of tokens
 
@@ -447,15 +452,13 @@ def __init__(
         unfolded_indices = unfoldNd(padded_indices, kernel_size = self.kernel_size, dilation = self.dilation)
         unfolded_indices = rearrange(unfolded_indices, '1 k n -> n k')
 
-        # if causal, compare query and key indices and make sure past cannot see future
-        # if not causal, just mask out the padding
+        # mask out padding
 
-        if causal:
-            mask = rearrange(indices, 'n -> n 1') < unfolded_indices
-        else:
-            mask = unfolded_indices == max_num_tokens
+        mask = unfolded_indices == max_num_tokens
+
+        # bos tokens never get masked out
 
-        mask = F.pad(mask, (1, 0), value = False) # bos tokens never get masked out
+        mask = F.pad(mask, (1, 0), value = False)
         self.register_buffer('mask', mask)
 
     def forward(self, x, **kwargs):
@@ -643,17 +646,17 @@ def __init__(
 
         self.kernel_size = (kernel_size, height)
         self.dilation = (dilation, 1)
-        self.padding = (calc_same_padding(kernel_size, dilation), 0)
 
+        self.causal_padding = (0, 0, calc_same_padding(kernel_size, dilation) * 2, 0)
         self.rel_pos_bias = AxialPositionalEmbedding(heads, shape = self.kernel_size) if exists(rel_pos_bias) else None
 
         # causal mask
 
-        self.register_buffer('causal_mask', None, persistent = False)
+        self.register_buffer('mask', None, persistent = False)
 
-    def get_causal_mask(self, t):
-        if exists(self.causal_mask) and self.causal_mask.shape[-3] == t.shape[-3]:
-            return self.causal_mask
+    def get_mask(self, t):
+        if exists(self.mask) and self.mask.shape[-3] == t.shape[-3]:
+            return self.mask
 
         device, seq_len = t.device, t.shape[-3] * self.height
         q_range = torch.arange(seq_len, device = device, dtype = torch.float32)
@@ -662,15 +665,15 @@ def get_causal_mask(t):
         q_range = rearrange(q_range, '(n m) -> n m', m = self.height)
         k_range = rearrange(k_range, '(n m) -> 1 1 n m', m = self.height)
 
-        k_range = F.pad(k_range, (0, 0, self.padding[0], self.padding[0]), value = seq_len)
+        k_range = F.pad(k_range, self.causal_padding, value = seq_len)
         k_range = unfoldNd(k_range, kernel_size = self.kernel_size, dilation = self.dilation)
         k_range = rearrange(k_range, '1 d n -> n d')
 
-        causal_mask = rearrange(q_range, 'n i -> n i 1') < rearrange(k_range, 'n j -> n 1 j')
-        causal_mask = F.pad(causal_mask, (1, 0), value = False)
+        mask = rearrange(k_range, 'n j -> n 1 j') == seq_len
+        mask = F.pad(mask, (1, 0), value = False)
 
-        self.register_buffer('causal_mask', causal_mask, persistent = False)
-        return causal_mask
+        self.register_buffer('mask', mask, persistent = False)
+        return mask
 
     def forward(
         self,
@@ -717,7 +720,8 @@ def forward(
         # reshape key / values to be unfolded
 
         k, v = map(lambda t: rearrange(t, 'b h (x y) d -> (b h) d x y ', y = tokens_per_timestep), (k, v))
-        k, v = map(lambda t: F.unfold(t, kernel_size = self.kernel_size, dilation = self.dilation, padding = self.padding), (k, v))
+        k, v = map(lambda t: F.pad(t, self.causal_padding), (k, v))
+        k, v = map(lambda t: F.unfold(t, kernel_size = self.kernel_size, dilation = self.dilation), (k, v))
         k, v = map(lambda t: rearrange(t, '(b h f) (d j) i -> b h i (f j) d', b = b, h = h, j = kernel_numel), (k, v))
 
         # add bos
@@ -741,8 +745,8 @@ def forward(
         # causal + padding mask
 
         mask_value = -torch.finfo(x.dtype).max
-        causal_mask = self.get_causal_mask(sim)
-        sim = sim.masked_fill(causal_mask, mask_value)
+        mask = self.get_mask(sim)
+        sim = sim.masked_fill(mask, mask_value)
 
         # attention
 
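Two ideas in the hunks above are worth spelling out. First, the padding mask: every token position is numbered, the index grid is padded on the causal side with a sentinel value, unfolded with the same kernel the attention uses, and any slot that comes back equal to the sentinel is masked; an extra leading column keeps the bos key visible. A toy 2d version of that recipe follows; it is my own illustration and uses F.unfold in place of unfoldNd to stay dependency-free.

```python
import torch
import torch.nn.functional as F
from einops import rearrange

h = w = 4
kernel_size = 3
num_tokens = h * w

# causal "same" padding: the full overhang goes before the position
pad = kernel_size - 1
causal_padding = (pad, 0, pad, 0)            # (w_left, w_right, h_top, h_bottom)

indices = torch.arange(num_tokens, dtype = torch.float32)
grid = rearrange(indices, '(h w) -> 1 1 h w', h = h)

padded = F.pad(grid, causal_padding, value = num_tokens)   # sentinel marks padding
unfolded = F.unfold(padded, kernel_size = kernel_size)     # (1, k*k, num_tokens)
unfolded = rearrange(unfolded, '1 k n -> n k')

mask = unfolded == num_tokens                # True where a query's window hits padding
mask = F.pad(mask, (1, 0), value = False)    # the bos key column is never masked out

print(mask.shape)                            # (16, 10): one row per query, bos + k*k keys
```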
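Second, the key/value gather in the attention forward pass: keys and values are padded on the past side only and then unfolded, so each query collects a local window that ends at its own frame, and get_mask only has to hide the padded slots. A simplified, hypothetical version of that step (shapes and names are mine, not the library's):

```python
import torch
import torch.nn.functional as F
from einops import rearrange

b, heads, dim_head = 2, 4, 16
frames, tokens_per_frame = 6, 8
kernel_size = 3                              # how many past frames each query attends to

k = torch.randn(b * heads, dim_head, frames, tokens_per_frame)

# pad (kernel_size - 1) frames of "past" before frame 0, nothing after
causal_padding = (0, 0, kernel_size - 1, 0)
k = F.pad(k, causal_padding)

k = F.unfold(k, kernel_size = (kernel_size, tokens_per_frame))
k = rearrange(k, '(b h) (d j) i -> b h i j d', b = b, h = heads, d = dim_head)

print(k.shape)   # (2, 4, 6, 24, 16): per query frame, a window of 3 frames x 8 tokens of keys
```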
2 changes: 1 addition & 1 deletion setup.py
@@ -4,7 +4,7 @@
   name = 'nuwa-pytorch',
   packages = find_packages(exclude=[]),
   include_package_data = True,
-  version = '0.6.4',
+  version = '0.7.0',
   license='MIT',
   description = 'NÜWA - Pytorch',
   author = 'Phil Wang',
