Skip to content

Commit

Permalink
vit + vit_mae are working
Browse files Browse the repository at this point in the history
  • Loading branch information
Sebastien Ehrhardt committed Apr 30, 2024
1 parent 547f6c4 commit 602913e
Show file tree
Hide file tree
Showing 6 changed files with 23 additions and 23 deletions.
10 changes: 7 additions & 3 deletions src/transformers/models/vit_mae/modeling_vit_mae.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,14 +248,14 @@ def random_masking(self, sequence, noise=None):
ids_restore = torch.argsort(ids_shuffle, dim=1)

# keep the first subset
ids_keep = ids_shuffle[:, :len_keep]
ids_keep = ids_shuffle[:, :len_keep].to(sequence.device)
sequence_unmasked = torch.gather(sequence, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, dim))

# generate the binary mask: 0 is keep, 1 is remove
mask = torch.ones([batch_size, seq_length], device=sequence.device)
mask[:, :len_keep] = 0
# unshuffle to get the binary mask
mask = torch.gather(mask, dim=1, index=ids_restore)
mask = torch.gather(mask, dim=1, index=ids_restore.to(sequence.device))

return sequence_unmasked, mask, ids_restore

Expand Down Expand Up @@ -813,7 +813,11 @@ def forward(
# append mask tokens to sequence
mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) # no cls token
x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) # unshuffle
x_ = torch.gather(
x_,
dim=1,
index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]).to(x_.device),
) # unshuffle
x = torch.cat([x[:, :1, :], x_], dim=1) # append cls token

# add pos embed
Expand Down
1 change: 1 addition & 0 deletions tests/models/vit/test_modeling_flax_vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ def prepare_config_and_inputs_for_common(self):
@require_flax
class FlaxViTModelTest(FlaxModelTesterMixin, unittest.TestCase):
all_model_classes = (FlaxViTModel, FlaxViTForImageClassification) if is_flax_available() else ()
has_attentions = False

def setUp(self) -> None:
self.model_tester = FlaxViTModelTester(self)
Expand Down
1 change: 1 addition & 0 deletions tests/models/vit/test_modeling_tf_vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ class TFViTModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
test_resize_embeddings = False
test_head_masking = False
test_onnx = False
has_attentions = False

def setUp(self):
self.model_tester = TFViTModelTester(self)
Expand Down
4 changes: 4 additions & 0 deletions tests/models/vit/test_modeling_vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def __init__(
initializer_range=0.02,
scope=None,
encoder_stride=2,
mask_ratio=0.5,
):
self.parent = parent
self.batch_size = batch_size
Expand All @@ -91,6 +92,9 @@ def __init__(
# in ViT, the seq length equals the number of patches + 1 (we add 1 for the [CLS] token)
num_patches = (image_size // patch_size) ** 2
self.seq_length = num_patches + 1
self.mask_ratio = mask_ratio
self.num_masks = int(mask_ratio * self.seq_length)
self.mask_length = num_patches

def prepare_config_and_inputs(self):
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
Expand Down
5 changes: 4 additions & 1 deletion tests/models/vit_mae/test_modeling_vit_mae.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,8 @@ def __init__(
type_sequence_label_size=10,
initializer_range=0.02,
num_labels=3,
mask_ratio=0.6,
scope=None,
mask_ratio=0.5,
):
self.parent = parent
self.batch_size = batch_size
Expand All @@ -89,6 +89,9 @@ def __init__(
# (we add 1 for the [CLS] token)
num_patches = (image_size // patch_size) ** 2
self.seq_length = int(math.ceil((1 - mask_ratio) * (num_patches + 1)))
self.mask_ratio = mask_ratio
self.num_masks = int(mask_ratio * self.seq_length)
self.mask_length = num_patches

def prepare_config_and_inputs(self):
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
Expand Down
25 changes: 6 additions & 19 deletions tests/test_modeling_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3717,40 +3717,27 @@ def get_mean_reldiff(failcase, x, ref, atol, rtol):
processed_inputs["attention_mask"] = dummy_attention_mask

if (
"bool_masked_pos"
in inspect.signature(model_eager.forward).parameters
"bool_masked_pos" in inspect.signature(model_eager.forward).parameters
) and not deactivate_mask:
dummy_mask = torch.ones(
(self.model_tester.num_masks,)
)
dummy_mask = torch.ones((self.model_tester.num_masks,))

                        # In case of an additional token (like the class token) we define a custom `mask_length`
if hasattr(self.model_tester, "mask_length"):
dummy_mask = torch.cat(
[
dummy_mask,
torch.zeros(
self.model_tester.mask_length
- dummy_mask.size(0)
),
torch.zeros(self.model_tester.mask_length - dummy_mask.size(0)),
]
)
else:
dummy_mask = torch.cat(
[
dummy_mask,
torch.zeros(
self.model_tester.seq_length
- dummy_mask.size(0)
),
torch.zeros(self.model_tester.seq_length - dummy_mask.size(0)),
]
)
dummy_bool_masked_pos = dummy_mask.expand(
batch_size, -1
).bool()
processed_inputs["bool_masked_pos"] = (
dummy_bool_masked_pos.to(torch_device)
)
dummy_bool_masked_pos = dummy_mask.expand(batch_size, -1).bool()
processed_inputs["bool_masked_pos"] = dummy_bool_masked_pos.to(torch_device)

if "noise" in inspect.signature(model_eager.forward).parameters:
np.random.seed(2)
Expand Down

0 comments on commit 602913e

Please sign in to comment.