[v0.1.2] Added the full model zoo and video MAE models
dbolya committed Jul 21, 2023
1 parent 1e3024d commit 1f825a3
Showing 7 changed files with 159 additions and 27 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -1,3 +1,7 @@
### **[2023.07.20]** v0.1.2
- Released the full model zoo.
- Added MAE functionality to the video models.

### **[2023.06.12]** v0.1.1
- Added the ability to specify multiple pretrained checkpoints per architecture (specify with `checkpoint=<ckpt_name>`).
- Added the ability to pass `strict=False` to a pretrained model so that you can use a different number of classes. **Note:** when changing the number of classes, the head layer will be reset.
50 changes: 33 additions & 17 deletions README.md
@@ -63,35 +63,37 @@ python setup.py build develop

## Model Zoo

Here we provide model checkpoints for Hiera. Each model listed is accessible on [torch hub](https://pytorch.org/docs/stable/hub.html), e.g.:
Here we provide model checkpoints for Hiera. Each model listed is accessible on [torch hub](https://pytorch.org/docs/stable/hub.html) even without the `hiera-transformer` package installed, e.g. the following initializes a base model pretrained and finetuned on ImageNet-1k:
```py
model = torch.hub.load("facebookresearch/hiera", model="hiera_base_224", pretrained=True, checkpoint="mae_in1k_ft_in1k")
```
For model names and corresponding checkpoint names see below.

**Note:** the speeds listed here were benchmarked _without_ PyTorch's optimized [scaled dot product attention](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html). If using PyTorch 2.0 or above, your inference speed will probably be faster than what's listed here.
If you want a model with MAE pretraining only, you can replace the checkpoint with `"mae_in1k"`. Additionally, if you'd like to load the MAE decoder as well (e.g., to continue pretraining), add `mae_` to the start of the model name, e.g.:
```py
model = torch.hub.load("facebookresearch/hiera", model="mae_hiera_base_224", pretrained=True, checkpoint="mae_in1k")
```
**Note:** Our MAE models were trained with a _normalized pixel loss_, meaning the patches were normalized before the network had to predict them. If you want to visualize the predictions, you'll have to unnormalize them, either using the visible patches (which might work but wouldn't be perfect) or using the ground truth. For more model names and corresponding checkpoint names, see below.
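If you'd like a feel for continuing pretraining with the decoder attached, the following is a rough sketch only — the batch size, learning rate, and random dummy data are illustrative assumptions, not values from the paper or this repo:
```py
import torch

# Sketch of one MAE pretraining step; hyperparameters and data are placeholders.
model = torch.hub.load("facebookresearch/hiera", model="mae_hiera_base_224",
                       pretrained=True, checkpoint="mae_in1k")
model.train()
optimizer = torch.optim.AdamW(model.parameters(), lr=1.5e-4)

x = torch.randn(2, 3, 224, 224)        # dummy batch of 224x224 images
loss, preds, labels, mask = model(x)   # loss is the normalized-pixel MAE loss
loss.backward()
optimizer.step()
optimizer.zero_grad()
```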

#### Coming Soon
As of now, base finetuned models are available. The rest are coming soon.

**Note:** the speeds listed here were benchmarked _without_ PyTorch's optimized [scaled dot product attention](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html). If using PyTorch 2.0 or above, your inference speed will probably be faster than what's listed here.
### Image Models
| Model | Model Name | Pretrained Models<br>(IN-1K MAE) | Finetuned Models<br>(IN-1K Supervised) | IN-1K<br>Top-1 (%) | A100 fp16<br>Speed (im/s) |
|----------|-----------------------|----------------------------------|----------------------------------------|:------------------:|:-------------------------:|
| Hiera-T | `hiera_tiny_224` | Coming Soon | [mae_in1k_ft_in1k](https://dl.fbaipublicfiles.com/hiera/hiera_tiny_224.pth) | 82.8 | 2758 |
| Hiera-S | `hiera_small_224` | Coming Soon | [mae_in1k_ft_in1k](https://dl.fbaipublicfiles.com/hiera/hiera_small_224.pth) | 83.8 | 2211 |
| Hiera-B | `hiera_base_224` | Coming Soon | [mae_in1k_ft_in1k](https://dl.fbaipublicfiles.com/hiera/hiera_base_224.pth) | 84.5 | 1556 |
| Hiera-B+ | `hiera_base_plus_224` | Coming Soon | [mae_in1k_ft_in1k](https://dl.fbaipublicfiles.com/hiera/hiera_base_plus_224.pth) | 85.2 | 1247 |
| Hiera-L | `hiera_large_224` | Coming Soon | [mae_in1k_ft_in1k](https://dl.fbaipublicfiles.com/hiera/hiera_large_224.pth) | 86.1 | 531 |
| Hiera-H | `hiera_huge_224` | Coming Soon | [mae_in1k_ft_in1k](https://dl.fbaipublicfiles.com/hiera/hiera_huge_224.pth) | 86.9 | 274 |
| Hiera-T | `hiera_tiny_224` | [mae_in1k](https://dl.fbaipublicfiles.com/hiera/mae_hiera_tiny_224.pth) | [mae_in1k_ft_in1k](https://dl.fbaipublicfiles.com/hiera/hiera_tiny_224.pth) | 82.8 | 2758 |
| Hiera-S | `hiera_small_224` | [mae_in1k](https://dl.fbaipublicfiles.com/hiera/mae_hiera_small_224.pth) | [mae_in1k_ft_in1k](https://dl.fbaipublicfiles.com/hiera/hiera_small_224.pth) | 83.8 | 2211 |
| Hiera-B | `hiera_base_224` | [mae_in1k](https://dl.fbaipublicfiles.com/hiera/mae_hiera_base_224.pth) | [mae_in1k_ft_in1k](https://dl.fbaipublicfiles.com/hiera/hiera_base_224.pth) | 84.5 | 1556 |
| Hiera-B+ | `hiera_base_plus_224` | [mae_in1k](https://dl.fbaipublicfiles.com/hiera/mae_hiera_base_plus_224.pth) | [mae_in1k_ft_in1k](https://dl.fbaipublicfiles.com/hiera/hiera_base_plus_224.pth) | 85.2 | 1247 |
| Hiera-L | `hiera_large_224` | [mae_in1k](https://dl.fbaipublicfiles.com/hiera/mae_hiera_large_224.pth) | [mae_in1k_ft_in1k](https://dl.fbaipublicfiles.com/hiera/hiera_large_224.pth) | 86.1 | 531 |
| Hiera-H | `hiera_huge_224` | [mae_in1k](https://dl.fbaipublicfiles.com/hiera/mae_hiera_huge_224.pth) | [mae_in1k_ft_in1k](https://dl.fbaipublicfiles.com/hiera/hiera_huge_224.pth) | 86.9 | 274 |

Each model inputs a 224x224 image.
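As a quick sanity check, here is a minimal sketch of image inference with one of the finetuned checkpoints above; the random tensor stands in for a properly resized and normalized RGB image:
```py
import torch

model = torch.hub.load("facebookresearch/hiera", model="hiera_base_224",
                       pretrained=True, checkpoint="mae_in1k_ft_in1k")
model.eval()

x = torch.randn(1, 3, 224, 224)  # stand-in for a preprocessed 224x224 image
with torch.no_grad():
    logits = model(x)            # [1, 1000] IN-1K class scores
pred_class = logits.argmax(dim=-1)
```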
### Video Models
| Model | Model Name | Pretrained Models<br>(K400 MAE) | Finetuned Models<br>(K400) | K400 (3x5 views)<br>Top-1 (%) | A100 fp16<br>Speed (clip/s) |
|----------|--------------------------|---------------------------------|----------------------------|:-----------------------------:|:---------------------------:|
| Hiera-B | `hiera_base_16x224` | Coming Soon | [mae_k400_ft_k400](https://dl.fbaipublicfiles.com/hiera/hiera_base_16x224.pth) | 84.0 | 133.6 |
| Hiera-B+ | `hiera_base_plus_16x224` | Coming Soon | Coming Soon | 85.0 | 84.1 |
| Hiera-L | `hiera_large_16x224` | Coming Soon | Coming Soon | 87.3 | 40.8 |
| Hiera-H | `hiera_huge_16x224` | Coming Soon | Coming Soon | 87.8 | 20.9 |
| Hiera-B | `hiera_base_16x224` | [mae_k400](https://dl.fbaipublicfiles.com/hiera/mae_hiera_base_16x224.pth) | [mae_k400_ft_k400](https://dl.fbaipublicfiles.com/hiera/hiera_base_16x224.pth) | 84.0 | 133.6 |
| Hiera-B+ | `hiera_base_plus_16x224` | [mae_k400](https://dl.fbaipublicfiles.com/hiera/mae_hiera_base_plus_16x224.pth) | [mae_k400_ft_k400](https://dl.fbaipublicfiles.com/hiera/hiera_base_plus_16x224.pth) | 85.0 | 84.1 |
| Hiera-L | `hiera_large_16x224` | [mae_k400](https://dl.fbaipublicfiles.com/hiera/mae_hiera_large_16x224.pth) | [mae_k400_ft_k400](https://dl.fbaipublicfiles.com/hiera/hiera_large_16x224.pth) | 87.3 | 40.8 |
| Hiera-H | `hiera_huge_16x224` | [mae_k400](https://dl.fbaipublicfiles.com/hiera/mae_hiera_huge_16x224.pth) | [mae_k400_ft_k400](https://dl.fbaipublicfiles.com/hiera/hiera_huge_16x224.pth) | 87.8 | 20.9 |

Each model inputs 16 224x224 frames with a temporal stride of 4.
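Video inference follows the same pattern; the sketch below assumes a `[batch, channels, frames, height, width]` layout, consistent with how `get_pixel_label_3d` in `hiera_mae.py` indexes the time dimension:
```py
import torch

model = torch.hub.load("facebookresearch/hiera", model="hiera_base_16x224",
                       pretrained=True, checkpoint="mae_k400_ft_k400")
model.eval()

clip = torch.randn(1, 3, 16, 224, 224)  # 16 frames sampled with temporal stride 4
with torch.no_grad():
    logits = model(clip)                 # [1, 400] Kinetics-400 class scores
```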

@@ -103,9 +105,9 @@ This repo implements the code to run Hiera models for inference. This repository
- [x] Image Inference
- [x] MAE implementation
- [x] Video Inference
- [ ] MAE implementation
- [x] MAE implementation
- [x] Full Model Zoo
- [ ] Training scripts
- [ ] Full Model Zoo


See [examples](https://github.com/facebookresearch/hiera/tree/main/examples) for examples of how to use Hiera.
@@ -130,6 +132,20 @@ Video inference works the same way, just use a `16x224` model instead.
output, intermediates = model(x, return_intermediates=True)
```

#### MAE Inference
By default, the models do not include the MAE decoder. If you would like to use the decoder or compute the MAE loss, you can instantiate an MAE version by running:
```py
import hiera
model = hiera.mae_hiera_base_224(pretrained=True, checkpoint="mae_in1k")
```
Then, when you run inference on the model, it returns a 4-tuple of `(loss, predictions, labels, mask)`, where the predictions and labels are for the _deleted tokens_ only. The returned mask will be `True` if the token is visible and `False` if it's deleted. You can change the masking ratio by passing it during inference:
```py
loss, preds, labels, mask = model(x, mask_ratio=0.6)
```
The default mask ratio is `0.6` for images, but you should pass in `0.9` for video. See the paper for details.
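For video, that looks like the following sketch (the random clip is an illustrative placeholder):
```py
import hiera
import torch

model = hiera.mae_hiera_base_16x224(pretrained=True, checkpoint="mae_k400")
clip = torch.randn(1, 3, 16, 224, 224)                   # dummy 16-frame clip
loss, preds, labels, mask = model(clip, mask_ratio=0.9)  # 90% of tokens masked
```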

**Note:** We use _normalized pixel targets_ for MAE pretraining, meaning each patch is individually normalized before the model has to predict it. Thus, you have to unnormalize the predictions using the ground truth before visualizing them. See `get_pixel_label_2d` in `hiera_mae.py` for details.
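For instance, here is a sketch of undoing the normalization with ground-truth statistics — `unnormalize_preds` and `gt_patches` are hypothetical names, and `gt_patches` is assumed to hold the unnormalized pixel patches for the masked tokens (e.g., obtained like the labels but with `norm=False`):
```py
import torch

def unnormalize_preds(preds: torch.Tensor, gt_patches: torch.Tensor) -> torch.Tensor:
    # Invert the per-patch normalization used for the MAE targets
    # (mean/var over each patch, eps matching get_pixel_label_3d).
    mean = gt_patches.mean(dim=-1, keepdim=True)
    var = gt_patches.var(dim=-1, keepdim=True)
    return preds * (var + 1.0e-6) ** 0.5 + mean
```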

### Benchmarking
We provide a script for easy benchmarking. See [examples/benchmark](https://github.com/facebookresearch/hiera/blob/main/examples/benchmark.ipynb) to see how to use it.

5 changes: 5 additions & 0 deletions hiera/__init__.py
@@ -34,5 +34,10 @@
mae_hiera_large_224,
mae_hiera_huge_224,

mae_hiera_base_16x224,
mae_hiera_base_plus_16x224,
mae_hiera_large_16x224,
mae_hiera_huge_16x224,

MaskedAutoencoderHiera,
)
23 changes: 20 additions & 3 deletions hiera/hiera.py
@@ -233,6 +233,7 @@ def __init__(
super().__init__()

depth = sum(stages)
self.patch_stride = patch_stride
self.tokens_spatial_shape = [i // s for i, s in zip(input_size, patch_stride)]
num_tokens = math.prod(self.tokens_spatial_shape)
flat_mu_size = math.prod(mask_unit_size)
@@ -438,41 +439,47 @@ def forward(

@pretrained_model({
"mae_in1k_ft_in1k": "https://dl.fbaipublicfiles.com/hiera/hiera_tiny_224.pth",
"mae_in1k": "https://dl.fbaipublicfiles.com/hiera/mae_hiera_tiny_224.pth",
}, default="mae_in1k_ft_in1k")
def hiera_tiny_224(**kwdargs):
return Hiera(embed_dim=96, num_heads=1, stages=(1, 2, 7, 2), **kwdargs)


@pretrained_model({
"mae_in1k_ft_in1k": "https://dl.fbaipublicfiles.com/hiera/hiera_small_224.pth",
"mae_in1k": "https://dl.fbaipublicfiles.com/hiera/mae_hiera_small_224.pth",
}, default="mae_in1k_ft_in1k")
def hiera_small_224(**kwdargs):
return Hiera(embed_dim=96, num_heads=1, stages=(1, 2, 11, 2), **kwdargs)


@pretrained_model({
"mae_in1k_ft_in1k": "https://dl.fbaipublicfiles.com/hiera/hiera_base_224.pth",
"mae_in1k": "https://dl.fbaipublicfiles.com/hiera/mae_hiera_base_224.pth",
}, default="mae_in1k_ft_in1k")
def hiera_base_224(**kwdargs):
return Hiera(embed_dim=96, num_heads=1, stages=(2, 3, 16, 3), **kwdargs)


@pretrained_model({
"mae_in1k_ft_in1k": "https://dl.fbaipublicfiles.com/hiera/hiera_base_plus_224.pth",
"mae_in1k": "https://dl.fbaipublicfiles.com/hiera/mae_hiera_base_plus_224.pth",
}, default="mae_in1k_ft_in1k")
def hiera_base_plus_224(**kwdargs):
return Hiera(embed_dim=112, num_heads=2, stages=(2, 3, 16, 3), **kwdargs)


@pretrained_model({
"mae_in1k_ft_in1k": "https://dl.fbaipublicfiles.com/hiera/hiera_large_224.pth",
"mae_in1k": "https://dl.fbaipublicfiles.com/hiera/mae_hiera_large_224.pth",
}, default="mae_in1k_ft_in1k")
def hiera_large_224(**kwdargs):
return Hiera(embed_dim=144, num_heads=2, stages=(2, 6, 36, 4), **kwdargs)


@pretrained_model({
"mae_in1k_ft_in1k": "https://dl.fbaipublicfiles.com/hiera/hiera_huge_224.pth",
"mae_in1k": "https://dl.fbaipublicfiles.com/hiera/mae_hiera_huge_224.pth",
}, default="mae_in1k_ft_in1k")
def hiera_huge_224(**kwdargs):
return Hiera(embed_dim=256, num_heads=4, stages=(2, 6, 36, 4), **kwdargs)
@@ -482,6 +489,7 @@ def hiera_huge_224(**kwdargs):

@pretrained_model({
"mae_k400_ft_k400": "https://dl.fbaipublicfiles.com/hiera/hiera_base_16x224.pth",
"mae_k400": "https://dl.fbaipublicfiles.com/hiera/mae_hiera_base_16x224.pth",
}, default="mae_k400_ft_k400")
def hiera_base_16x224(num_classes: int = 400, **kwdargs):
return Hiera(
@@ -497,21 +505,30 @@ def hiera_base_16x224(num_classes: int = 400, **kwdargs):
)


@pretrained_model(None)
@pretrained_model({
"mae_k400_ft_k400": "https://dl.fbaipublicfiles.com/hiera/hiera_base_plus_16x224.pth",
"mae_k400": "https://dl.fbaipublicfiles.com/hiera/mae_hiera_base_plus_16x224.pth",
}, default="mae_k400_ft_k400")
def hiera_base_plus_16x224(**kwdargs):
return hiera_base_16x224(
embed_dim=112, num_heads=2, stages=(2, 3, 16, 3), **kwdargs
)


@pretrained_model(None)
@pretrained_model({
"mae_k400_ft_k400": "https://dl.fbaipublicfiles.com/hiera/hiera_large_16x224.pth",
"mae_k400": "https://dl.fbaipublicfiles.com/hiera/mae_hiera_large_16x224.pth",
}, default="mae_k400_ft_k400")
def hiera_large_16x224(**kwdargs):
return hiera_base_16x224(
embed_dim=144, num_heads=2, stages=(2, 6, 36, 4), **kwdargs
)


@pretrained_model(None)
@pretrained_model({
"mae_k400_ft_k400": "https://dl.fbaipublicfiles.com/hiera/hiera_huge_16x224.pth",
"mae_k400": "https://dl.fbaipublicfiles.com/hiera/mae_hiera_huge_16x224.pth",
}, default="mae_k400_ft_k400")
def hiera_huge_16x224(**kwdargs):
return hiera_base_16x224(
embed_dim=256, num_heads=4, stages=(2, 6, 36, 4), **kwdargs
98 changes: 92 additions & 6 deletions hiera/hiera_mae.py
@@ -161,6 +161,28 @@ def get_pixel_label_2d(

return label

def get_pixel_label_3d(
self, input_vid: torch.Tensor, mask: torch.Tensor, norm: bool = True
) -> torch.Tensor:
# mask (boolean tensor): True must correspond to *masked*

# We use a time-strided loss: only take the first frame from each token
input_vid = input_vid[:, :, ::self.patch_stride[0], :, :]

size = self.pred_stride
label = input_vid.unfold(3, size, size).unfold(4, size, size)
label = label.permute(0, 2, 3, 4, 5, 6, 1) # Different from 2d, mistake during training lol
label = label.flatten(1, 3).flatten(2)
label = label[mask]

if norm:
mean = label.mean(dim=-1, keepdim=True)
var = label.var(dim=-1, keepdim=True)
label = (label - mean) / (var + 1.0e-6) ** 0.5

return label


def forward_encoder(
self, x: torch.Tensor, mask_ratio: float, mask: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -242,6 +264,8 @@ def forward_loss(
"""
if len(self.q_stride) == 2:
label = self.get_pixel_label_2d(x, mask)
elif len(self.q_stride) == 3:
label = self.get_pixel_label_3d(x, mask)
else:
raise NotImplementedError

@@ -270,43 +294,105 @@ def forward(

# Image Models

@pretrained_model(None)
@pretrained_model({
"mae_in1k": "https://dl.fbaipublicfiles.com/hiera/mae_hiera_tiny_224.pth",
}, default="mae_in1k")
def mae_hiera_tiny_224(**kwargs):
return MaskedAutoencoderHiera(
embed_dim=96, num_heads=1, stages=(1, 2, 7, 2), q_pool=2, **kwargs,
)


@pretrained_model(None)
@pretrained_model({
"mae_in1k": "https://dl.fbaipublicfiles.com/hiera/mae_hiera_small_224.pth",
}, default="mae_in1k")
def mae_hiera_small_224(**kwargs):
return MaskedAutoencoderHiera(
embed_dim=96, num_heads=1, stages=(1, 2, 11, 2), q_pool=2, **kwargs,
)


@pretrained_model(None)
@pretrained_model({
"mae_in1k": "https://dl.fbaipublicfiles.com/hiera/mae_hiera_base_224.pth",
}, default="mae_in1k")
def mae_hiera_base_224(**kwargs):
return MaskedAutoencoderHiera(
embed_dim=96, num_heads=1, stages=(2, 3, 16, 3), q_pool=2, **kwargs,
)


@pretrained_model(None)
@pretrained_model({
"mae_in1k": "https://dl.fbaipublicfiles.com/hiera/mae_hiera_base_plus_224.pth",
}, default="mae_in1k")
def mae_hiera_base_plus_224(**kwargs):
return MaskedAutoencoderHiera(
embed_dim=112, num_heads=2, stages=(2, 3, 16, 3), q_pool=2, **kwargs,
)


@pretrained_model(None)
@pretrained_model({
"mae_in1k": "https://dl.fbaipublicfiles.com/hiera/mae_hiera_large_224.pth",
}, default="mae_in1k")
def mae_hiera_large_224(**kwargs):
return MaskedAutoencoderHiera(
embed_dim=144, num_heads=2, stages=(2, 6, 36, 4), q_pool=2, **kwargs,
)


@pretrained_model(None)
@pretrained_model({
"mae_in1k": "https://dl.fbaipublicfiles.com/hiera/mae_hiera_huge_224.pth",
}, default="mae_in1k")
def mae_hiera_huge_224(**kwargs):
return MaskedAutoencoderHiera(
embed_dim=256, num_heads=4, stages=(2, 6, 36, 4), q_pool=2, **kwargs,
)



# Video Models

@pretrained_model({
"mae_k400": "https://dl.fbaipublicfiles.com/hiera/mae_hiera_base_16x224.pth",
}, default="mae_k400")
def mae_hiera_base_16x224(num_classes: int = 400, **kwdargs):
return MaskedAutoencoderHiera(
num_classes=num_classes, # K400 has 400 classes
input_size=(16, 224, 224),
q_stride=(1, 2, 2),
mask_unit_size=(1, 8, 8),
patch_kernel=(3, 7, 7),
patch_stride=(2, 4, 4),
patch_padding=(1, 3, 3),
sep_pos_embed=True,
q_pool=2,
**kwdargs
)


@pretrained_model({
"mae_k400": "https://dl.fbaipublicfiles.com/hiera/mae_hiera_base_plus_16x224.pth",
}, default="mae_k400")
@pretrained_model(None)
def mae_hiera_base_plus_16x224(**kwdargs):
return mae_hiera_base_16x224(
embed_dim=112, num_heads=2, stages=(2, 3, 16, 3), **kwdargs
)


@pretrained_model({
"mae_k400": "https://dl.fbaipublicfiles.com/hiera/mae_hiera_large_16x224.pth",
}, default="mae_k400")
@pretrained_model(None)
def mae_hiera_large_16x224(**kwdargs):
return mae_hiera_base_16x224(
embed_dim=144, num_heads=2, stages=(2, 6, 36, 4), **kwdargs
)


@pretrained_model({
"mae_k400": "https://dl.fbaipublicfiles.com/hiera/mae_hiera_huge_16x224.pth",
}, default="mae_k400")
def mae_hiera_huge_16x224(**kwdargs):
return mae_hiera_base_16x224(
embed_dim=256, num_heads=4, stages=(2, 6, 36, 4), **kwdargs
)
4 changes: 4 additions & 0 deletions hiera/hiera_utils.py
@@ -52,6 +52,10 @@ def model_def(pretrained: bool = False, checkpoint: str = default, strict: bool

model = model_func(**kwdargs)
if pretrained:
# Disable strict loading when trying to load an encoder-decoder checkpoint into an encoder-only model
if "decoder_pos_embed" in state_dict["model_state"] and not hasattr(model, "decoder_pos_embed"):
strict = False

model.load_state_dict(state_dict["model_state"], strict=strict)

return model
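In practice this means an MAE (encoder-decoder) checkpoint can now be loaded into an encoder-only model without passing `strict=False` yourself; a minimal sketch using the package-level API documented above:
```py
import hiera

# The decoder weights in the checkpoint are skipped automatically because
# strict loading is disabled when the target model has no decoder.
model = hiera.hiera_base_224(pretrained=True, checkpoint="mae_in1k")
```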
2 changes: 1 addition & 1 deletion setup.py
@@ -9,7 +9,7 @@

setup(
name="hiera-transformer",
version="0.1.1",
version="0.1.2",
author="Chaitanya Ryali, Daniel Bolya",
url="https://github.com/facebookresearch/hiera",
description="A fast, powerful, and simple hierarchical vision transformer",
