Skip to content

Commit

Permalink
rotary embeddings need to be computed in full precision (autocast disabled) to be numerically safe
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed May 11, 2024
1 parent bca88e9 commit 90be723
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 1 deletion.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
setup(
name = 'vit-pytorch',
packages = find_packages(exclude=['examples']),
version = '1.6.8',
version = '1.6.9',
license='MIT',
description = 'Vision Transformer (ViT) - Pytorch',
long_description=long_description,
Expand Down
3 changes: 3 additions & 0 deletions vit_pytorch/rvt.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@
import torch
from torch import nn, einsum
import torch.nn.functional as F
from torch.cuda.amp import autocast

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# rotary embeddings

@autocast(enabled = False)
def rotate_every_two(x):
x = rearrange(x, '... (d j) -> ... d j', j = 2)
x1, x2 = x.unbind(dim = -1)
Expand All @@ -22,6 +24,7 @@ def __init__(self, dim, max_freq = 10):
scales = torch.linspace(1., max_freq / 2, self.dim // 4)
self.register_buffer('scales', scales)

@autocast(enabled = False)
def forward(self, x):
device, dtype, n = x.device, x.dtype, int(sqrt(x.shape[-2]))

Expand Down

0 comments on commit 90be723

Please sign in to comment.