do attention in float32 always
lucidrains committed Aug 14, 2022
1 parent bf1f3dc commit 1f09e65
Showing 2 changed files with 8 additions and 12 deletions.
nuwa_pytorch/nuwa_pytorch.py (17 changes: 6 additions & 11 deletions)
@@ -68,11 +68,6 @@ def gumbel_sample(t, temperature = 1., dim = -1):
def safe_div(numer, denom, eps = 1e-6):
return numer / (denom + eps)

- def stable_softmax(t, dim = -1, alpha = 32 ** 2):
- t = t / alpha
- t = t - torch.amax(t, dim = dim, keepdim = True).detach()
- return (t * alpha).softmax(dim = dim)
-
def prob_mask_like(shape, prob, device):
return torch.zeros(shape, device = device).float().uniform_(0, 1) < prob
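
For context, here is a minimal standalone sketch (a hypothetical example, not part of the repo) contrasting the removed stable_softmax helper with the replacement call used throughout this commit. Tensor.softmax accepts a dtype argument that casts the logits to that dtype before the softmax is computed, so the attention distribution is always formed in float32 without the divide-by-alpha rescaling trick.

    import torch

    def stable_softmax(t, dim = -1, alpha = 32 ** 2):
        # old helper: shrink the logits, subtract the max, then undo the shrink
        t = t / alpha
        t = t - torch.amax(t, dim = dim, keepdim = True).detach()
        return (t * alpha).softmax(dim = dim)

    sim = torch.randn(2, 8, 16, 16) * 10                    # stand-in attention logits

    old = stable_softmax(sim, dim = -1)
    new = sim.softmax(dim = -1, dtype = torch.float32)      # upcast happens inside softmax

    assert torch.allclose(old, new, atol = 1e-4)            # same distribution, up to rounding

    half = sim.half().softmax(dim = -1, dtype = torch.float32)
    print(half.dtype)                                       # torch.float32 even for fp16 logits

Both paths compute the same function; the dtype route simply guarantees the exponentiation and normalization happen in full precision even when the rest of the model runs in half precision.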

@@ -373,7 +368,7 @@ def forward(

# attention

- attn = stable_softmax(sim, dim = -1)
+ attn = sim.softmax(dim = -1, dtype = torch.float32)
attn = self.talking_heads(attn)
attn = self.dropout(attn)

@@ -556,7 +551,7 @@ def attend(q, k, v, mask, k_bos, v_bos, kernel_size):

# attention

- attn = stable_softmax(sim, dim = -1)
+ attn = sim.softmax(dim = -1, dtype = torch.float32)

attn = rearrange(attn, '(b h) ... -> b h ...', h = h)
attn = self.talking_heads(attn)
@@ -748,7 +743,7 @@ def forward(

# attention

- attn = stable_softmax(sim, dim = -1)
+ attn = sim.softmax(dim = -1, dtype = torch.float32)
attn = self.talking_heads(attn)
attn = self.dropout(attn)

@@ -844,7 +839,7 @@ def forward(
bos_context_mask = F.pad(bos_context_mask, (1, 0), value = True)
sim_bos = sim_bos.masked_fill(~bos_context_mask, mask_value)

- attn_bos = stable_softmax(sim_bos, dim = -1)
+ attn_bos = sim_bos.softmax(dim = -1, dtype = torch.float32)
out_bos = einsum('b h j, b h j d -> b h d', attn_bos, v_for_bos)
out_bos = rearrange(out_bos, 'b h d -> b 1 (h d)')

@@ -890,7 +885,7 @@ def forward(

# attention

- attn = stable_softmax(sim, dim = -1)
+ attn = sim.softmax(dim = -1, dtype = torch.float32)
attn = self.talking_heads(attn)
attn = self.dropout(attn)

@@ -1050,7 +1045,7 @@ def forward(
context_mask = F.pad(context_mask, (1, 0), value = True) # null key / value
sim = sim.masked_fill(~context_mask, max_neg_value)

- attn = stable_softmax(sim, dim = -1)
+ attn = sim.softmax(dim = -1, dtype = torch.float32)
attn = self.dropout(attn)

attn = self.talking_heads(attn)
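
The same one-line swap recurs in every attention block above. As a consolidated illustration, here is a condensed, hypothetical version of that attention step (the names and signature are placeholders, not the repo's modules): mask the similarity logits, softmax them in float32, apply dropout, then aggregate the values, casting the weights back to the value dtype so the einsum operands agree outside of autocast.

    import torch
    import torch.nn.functional as F
    from torch import einsum

    def attend(q, k, v, mask = None, dropout_p = 0.):
        # q, k, v: (batch, heads, seq, dim_head); mask: bool, broadcastable to the logits
        scale = q.shape[-1] ** -0.5
        sim = einsum('b h i d, b h j d -> b h i j', q, k) * scale

        if mask is not None:
            sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)

        attn = sim.softmax(dim = -1, dtype = torch.float32)    # the change made in this commit
        attn = F.dropout(attn, p = dropout_p)

        # cast the fp32 weights back to the value dtype before aggregating
        return einsum('b h i j, b h j d -> b h i d', attn.type_as(v), v)

Under torch.cuda.amp.autocast the dtype mismatch would typically be reconciled automatically; the explicit type_as just keeps the sketch self-contained when running in plain fp16.
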
setup.py (3 changes: 2 additions & 1 deletion)
@@ -4,9 +4,10 @@
name = 'nuwa-pytorch',
packages = find_packages(exclude=[]),
include_package_data = True,
- version = '0.7.6',
+ version = '0.7.7',
license='MIT',
description = 'NÜWA - Pytorch',
+ long_description_content_type = 'text/markdown',
author = 'Phil Wang',
author_email = 'lucidrains@gmail.com',
url = 'https://github.com/lucidrains/nuwa-pytorch',
