From 649a91677ce097e81d1f41b98f64bf11aa336bd8 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Wed, 23 Mar 2022 21:07:12 -0700
Subject: [PATCH] move to using causal convolutions through asymmetric padding

---
 nuwa_pytorch/nuwa_pytorch.py | 46 ++++++++++++++++++++----------------
 setup.py                     |  2 +-
 2 files changed, 26 insertions(+), 22 deletions(-)

diff --git a/nuwa_pytorch/nuwa_pytorch.py b/nuwa_pytorch/nuwa_pytorch.py
index 37e321c..45ba90f 100644
--- a/nuwa_pytorch/nuwa_pytorch.py
+++ b/nuwa_pytorch/nuwa_pytorch.py
@@ -426,7 +426,12 @@ def __init__(
         self.padding_height = calc_same_padding(self.kernel_size[1], self.dilation[1])
         self.padding_width = calc_same_padding(self.kernel_size[2], self.dilation[2])
 
-        self.video_padding = (self.padding_width, self.padding_width, self.padding_height, self.padding_height, self.padding_frame, self.padding_frame)
+        # use separate padding for causal convolution vs non-causal
+
+        if causal:
+            self.video_padding = (self.padding_width * 2, 0, self.padding_height * 2, 0, self.padding_frame * 2, 0)
+        else:
+            self.video_padding = (self.padding_width, self.padding_width, self.padding_height, self.padding_height, self.padding_frame, self.padding_frame)
 
         # save video shape and calculate max number of tokens
 
@@ -447,15 +452,13 @@ def __init__(
         unfolded_indices = unfoldNd(padded_indices, kernel_size = self.kernel_size, dilation = self.dilation)
         unfolded_indices = rearrange(unfolded_indices, '1 k n -> n k')
 
-        # if causal, compare query and key indices and make sure past cannot see future
-        # if not causal, just mask out the padding
+        # mask out padding
 
-        if causal:
-            mask = rearrange(indices, 'n -> n 1') < unfolded_indices
-        else:
-            mask = unfolded_indices == max_num_tokens
+        mask = unfolded_indices == max_num_tokens
+
+        # bos tokens never get masked out
 
-        mask = F.pad(mask, (1, 0), value = False) # bos tokens never get masked out
+        mask = F.pad(mask, (1, 0), value = False)
         self.register_buffer('mask', mask)
 
     def forward(self, x, **kwargs):
@@ -643,17 +646,17 @@ def __init__(
         self.kernel_size = (kernel_size, height)
         self.dilation = (dilation, 1)
 
-        self.padding = (calc_same_padding(kernel_size, dilation), 0)
+        self.causal_padding = (0, 0, calc_same_padding(kernel_size, dilation) * 2, 0)
 
         self.rel_pos_bias = AxialPositionalEmbedding(heads, shape = self.kernel_size) if exists(rel_pos_bias) else None
 
         # causal mask
 
-        self.register_buffer('causal_mask', None, persistent = False)
+        self.register_buffer('mask', None, persistent = False)
 
-    def get_causal_mask(self, t):
-        if exists(self.causal_mask) and self.causal_mask.shape[-3] == t.shape[-3]:
-            return self.causal_mask
+    def get_mask(self, t):
+        if exists(self.mask) and self.mask.shape[-3] == t.shape[-3]:
+            return self.mask
 
         device, seq_len = t.device, t.shape[-3] * self.height
         q_range = torch.arange(seq_len, device = device, dtype = torch.float32)
@@ -662,15 +665,15 @@ def get_causal_mask(self, t):
         q_range = rearrange(q_range, '(n m) -> n m', m = self.height)
         k_range = rearrange(k_range, '(n m) -> 1 1 n m', m = self.height)
 
-        k_range = F.pad(k_range, (0, 0, self.padding[0], self.padding[0]), value = seq_len)
+        k_range = F.pad(k_range, self.causal_padding, value = seq_len)
         k_range = unfoldNd(k_range, kernel_size = self.kernel_size, dilation = self.dilation)
         k_range = rearrange(k_range, '1 d n -> n d')
 
-        causal_mask = rearrange(q_range, 'n i -> n i 1') < rearrange(k_range, 'n j -> n 1 j')
-        causal_mask = F.pad(causal_mask, (1, 0), value = False)
+        mask = rearrange(k_range, 'n j -> n 1 j') == seq_len
+        mask = F.pad(mask, (1, 0), value = False)
 
-        self.register_buffer('causal_mask', causal_mask, persistent = False)
-        return causal_mask
+        self.register_buffer('mask', mask, persistent = False)
+        return mask
 
     def forward(
         self,
@@ -717,7 +720,8 @@ def forward(
         # reshape key / values to be unfolded
 
         k, v = map(lambda t: rearrange(t, 'b h (x y) d -> (b h) d x y ', y = tokens_per_timestep), (k, v))
-        k, v = map(lambda t: F.unfold(t, kernel_size = self.kernel_size, dilation = self.dilation, padding = self.padding), (k, v))
+        k, v = map(lambda t: F.pad(t, self.causal_padding), (k, v))
+        k, v = map(lambda t: F.unfold(t, kernel_size = self.kernel_size, dilation = self.dilation), (k, v))
         k, v = map(lambda t: rearrange(t, '(b h f) (d j) i -> b h i (f j) d', b = b, h = h, j = kernel_numel), (k, v))
 
         # add bos
@@ -741,8 +745,8 @@ def forward(
         # causal + padding mask
 
         mask_value = -torch.finfo(x.dtype).max
-        causal_mask = self.get_causal_mask(sim)
-        sim = sim.masked_fill(causal_mask, mask_value)
+        mask = self.get_mask(sim)
+        sim = sim.masked_fill(mask, mask_value)
 
         # attention
 
diff --git a/setup.py b/setup.py
index 1a70d40..f8fc6cc 100644
--- a/setup.py
+++ b/setup.py
@@ -4,7 +4,7 @@
   name = 'nuwa-pytorch',
   packages = find_packages(exclude=[]),
   include_package_data = True,
-  version = '0.6.4',
+  version = '0.7.0',
   license='MIT',
   description = 'NÜWA - Pytorch',
   author = 'Phil Wang',
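
A note on the first hunk, for readers skimming the diff: a causal convolution is an ordinary convolution whose "same" padding has all been shifted to one side, so no output position ever convolves over future inputs. The 1D sketch below is illustrative only, not code from this repo; calc_same_padding is reconstructed from how the patch calls it and should be treated as an assumption.

    import torch
    import torch.nn.functional as F

    def calc_same_padding(kernel_size, dilation = 1):
        # assumed reconstruction of the repo helper: per-side padding that
        # preserves sequence length for an odd kernel
        return dilation * (kernel_size - 1) // 2

    kernel_size, dilation = 3, 2
    pad = calc_same_padding(kernel_size, dilation)

    x = torch.randn(1, 1, 10)               # (batch, channels, time)
    weight = torch.randn(1, 1, kernel_size)

    # non-causal: symmetric padding, position t sees t - pad .. t + pad
    same = F.conv1d(F.pad(x, (pad, pad)), weight, dilation = dilation)

    # causal: the same total padding, all on the left (the past), mirroring
    # the patch's (self.padding_width * 2, 0, ...) for the video dimensions
    causal = F.conv1d(F.pad(x, (pad * 2, 0)), weight, dilation = dilation)

    assert same.shape == causal.shape == x.shape

Output length is preserved in both cases; only the receptive field shifts, which is what lets the rest of the patch drop the explicit causal comparison.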
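
The mask changes in the remaining hunks follow from the same move. With all padding on the causal side and a sentinel pad value (max_num_tokens in the 3DNA branch, seq_len in the 2DNA branch), each unfolded key window can only contain the current and earlier positions, so masking the sentinel is equivalent to the old past-cannot-see-future comparison. A small sketch of that reasoning, again illustrative, with tensor.unfold standing in for the unfoldNd dependency:

    import torch
    import torch.nn.functional as F

    seq_len, kernel_size = 6, 3
    causal_pad = kernel_size - 1            # all on the left, none on the right

    # key positions 0..5, padded with the out-of-range sentinel seq_len
    k_range = torch.arange(seq_len, dtype = torch.float32).view(1, 1, seq_len)
    k_range = F.pad(k_range, (causal_pad, 0), value = seq_len)

    # one window of candidate keys per query position
    windows = k_range.unfold(-1, kernel_size, 1).reshape(seq_len, kernel_size)

    print(windows)
    # tensor([[6., 6., 0.],
    #         [6., 0., 1.],
    #         [0., 1., 2.],
    #         [1., 2., 3.],
    #         [2., 3., 4.],
    #         [3., 4., 5.]])
    #
    # query i only ever sees keys <= i, so masking the sentinel suffices

    mask = windows == seq_len               # True only at padding positions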