
fix t2t vit having two layernorms, and make final layernorm in distillation wrapper configurable, default to False for vit
lucidrains committed Jun 11, 2024
1 parent 90be723 commit e3256d7
Showing 3 changed files with 17 additions and 14 deletions.
setup.py (1 addition, 1 deletion)

@@ -6,7 +6,7 @@
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '1.6.9',
+  version = '1.7.0',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   long_description=long_description,
vit_pytorch/distill.py (15 additions, 9 deletions)

@@ -1,6 +1,8 @@
 import torch
-import torch.nn.functional as F
 from torch import nn
+from torch.nn import Module
+import torch.nn.functional as F
+
 from vit_pytorch.vit import ViT
 from vit_pytorch.t2t import T2TViT
 from vit_pytorch.efficient import ViT as EfficientViT
@@ -12,6 +14,9 @@
 def exists(val):
     return val is not None

+def default(val, d):
+    return val if exists(val) else d
+
 # classes

 class DistillMixin:
@@ -20,12 +25,12 @@ def forward(self, img, distill_token = None):
         x = self.to_patch_embedding(img)
         b, n, _ = x.shape

-        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
+        cls_tokens = repeat(self.cls_token, '1 n d -> b n d', b = b)
         x = torch.cat((cls_tokens, x), dim = 1)
         x += self.pos_embedding[:, :(n + 1)]

         if distilling:
-            distill_tokens = repeat(distill_token, '() n d -> b n d', b = b)
+            distill_tokens = repeat(distill_token, '1 n d -> b n d', b = b)
             x = torch.cat((x, distill_tokens), dim = 1)

         x = self._attend(x)
@@ -97,15 +102,16 @@ def _attend(self, x):

 # knowledge distillation wrapper

-class DistillWrapper(nn.Module):
+class DistillWrapper(Module):
     def __init__(
         self,
         *,
         teacher,
         student,
         temperature = 1.,
         alpha = 0.5,
-        hard = False
+        hard = False,
+        mlp_layernorm = False
     ):
         super().__init__()
         assert (isinstance(student, (DistillableViT, DistillableT2TViT, DistillableEfficientViT))) , 'student must be a vision transformer'
@@ -122,14 +128,14 @@ def __init__(
         self.distillation_token = nn.Parameter(torch.randn(1, 1, dim))

         self.distill_mlp = nn.Sequential(
-            nn.LayerNorm(dim),
+            nn.LayerNorm(dim) if mlp_layernorm else nn.Identity(),
             nn.Linear(dim, num_classes)
         )

     def forward(self, img, labels, temperature = None, alpha = None, **kwargs):
         b, *_ = img.shape
-        alpha = alpha if exists(alpha) else self.alpha
-        T = temperature if exists(temperature) else self.temperature

+        alpha = default(alpha, self.alpha)
+        T = default(temperature, self.temperature)

         with torch.no_grad():
             teacher_logits = self.teacher(img)
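
As a usage sketch of the new flag (hyperparameters are illustrative and follow the style of the repository README; mlp_layernorm is the only argument introduced by this commit), the keyword is passed straight to DistillWrapper, and leaving it at its default of False keeps the distillation head as a plain Linear:

import torch
from torchvision.models import resnet50

from vit_pytorch.distill import DistillableViT, DistillWrapper

teacher = resnet50(pretrained = True)

student = DistillableViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

distiller = DistillWrapper(
    student = student,
    teacher = teacher,
    temperature = 3,        # distillation temperature
    alpha = 0.5,            # trade-off between task loss and distillation loss
    hard = False,           # soft distillation
    mlp_layernorm = False   # new in this commit: set True to restore the LayerNorm in the distillation head
)

img = torch.randn(2, 3, 256, 256)
labels = torch.randint(0, 1000, (2,))

loss = distiller(img, labels)   # combined cross-entropy and distillation loss
loss.backward()
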
vit_pytorch/t2t.py (1 addition, 4 deletions)

@@ -61,10 +61,7 @@ def __init__(self, *, image_size, num_classes, dim, depth = None, heads = None,
         self.pool = pool
         self.to_latent = nn.Identity()

-        self.mlp_head = nn.Sequential(
-            nn.LayerNorm(dim),
-            nn.Linear(dim, num_classes)
-        )
+        self.mlp_head = nn.Linear(dim, num_classes)

     def forward(self, img):
         x = self.to_patch_embedding(img)
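
For reference, a short sketch exercising the simplified T2T-ViT head (hyperparameters mirror the repository README and are illustrative only). The LayerNorm removed here was redundant, presumably because the features reaching mlp_head are already layer-normalized by the shared Transformer, which is what the commit title means by "two layernorms":

import torch
from vit_pytorch.t2t import T2TViT

v = T2TViT(
    dim = 512,
    image_size = 224,
    depth = 5,
    heads = 8,
    mlp_dim = 512,
    num_classes = 1000,
    t2t_layers = ((7, 4), (3, 2), (3, 2))  # kernel size and stride of each tokens-to-token layer
)

img = torch.randn(1, 3, 224, 224)

preds = v(img)  # (1, 1000) -- logits now come from a single nn.Linear head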
