SAIM

Official PyTorch Implementation of Exploring Stochastic Autoregressive Image Modeling for Visual Representation, Accepted by AAAI 2023.

Introduction

(Figure: overview of the SAIM pretraining pipeline)

SAIM is a novel self-supervised pre-training framework that performs autoregressive image modeling with a stochastic permutation strategy. Our method significantly improves the performance of autoregressive image modeling and achieves the best top-1 accuracy (83.9%) on the vanilla ViT-Base model among methods that use only ImageNet-1K data.
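
The core idea can be sketched in a few lines. The snippet below is an illustrative sketch, not the repository's actual code (the full model is more involved, e.g. it adds a separate query stream via --query_depth): it samples a random generation order over image patches and builds the attention mask that lets each patch attend only to patches that come no later in the sampled order, realizing the stochastic autoregressive factorization described above.

import torch

def stochastic_autoregressive_mask(num_patches: int) -> torch.Tensor:
    # Sample a random generation order over the patches.
    perm = torch.randperm(num_patches)
    # rank[p] = position of patch p in the sampled order.
    rank = torch.empty_like(perm)
    rank[perm] = torch.arange(num_patches)
    # Patch i may attend to patch j iff j appears no later than i
    # in the sampled order (a permuted causal mask).
    return rank.unsqueeze(1) >= rank.unsqueeze(0)

mask = stochastic_autoregressive_mask(196)  # 14 x 14 patches for ViT-B/16 at 224 x 224
print(mask.shape)  # torch.Size([196, 196])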

Main Results on ImageNet-1K

The pretrained checkpoint and training logs used in the paper are available below.

SAIM-Base: pretrained checkpoint (download), logs (download)

The fine-tuning (FT) and linear probing (LIN) results on ImageNet-1K are as follows:

Models   | Architecture | Pretrain Epochs | FT acc@1 (%) | LIN acc@1 (%) | FT logs/weights | LIN logs/weights
BEiT     | ViT-B        | 800             | 83.2         | 37.6          | -               | -
MAE      | ViT-B        | 1600            | 83.6         | 67.8          | -               | -
SimMIM   | ViT-B        | 1600            | 83.8         | 56.7          | -               | -
iGPT     | iGPT-L       | -               | 72.6         | 65.2          | -               | -
ViT-iGPT | ViT-B        | 300             | 82.7         | 20.4          | -               | -
SAIM     | ViT-B        | 300             | 83.6         | 58.5          | -               | -
SAIM     | ViT-B        | 800             | 83.9         | 62.5          | log/weight      | log/weight

Getting Started

Install

  • Clone this repo:
git clone https://github.com/qiy20/SAIM
cd SAIM
  • Create a conda environment and activate it:
conda create -n saim python=3.9
conda activate saim
  • Install PyTorch==1.13.0 and torchvision==0.14.0 with CUDA==11.6:
conda install pytorch torchvision torchaudio pytorch-cuda=11.6 -c pytorch -c nvidia
  • Install timm==0.4.5:
pip install timm==0.4.5
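
A quick sanity check after installation (a minimal sketch; the printed versions depend on what you actually installed):

import timm
import torch
import torchvision

print(torch.__version__)          # expect 1.13.0
print(torchvision.__version__)    # expect 0.14.0
print(timm.__version__)           # expect 0.4.5
print(torch.cuda.is_available())  # should be True for distributed training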

Data preparation

You can download ImageNet-1K here and prepare it in the following format:

imagenet
  ├── train
  │   ├── class1
  │   │   ├── img1.jpeg
  │   │   ├── img2.jpeg
  │   │   └── ...
  │   ├── class2
  │   │   ├── img3.jpeg
  │   │   └── ...
  │   └── ...
  └── val
      ├── class1
      │   ├── img4.jpeg
      │   ├── img5.jpeg
      │   └── ...
      ├── class2
      │   ├── img6.jpeg
      │   └── ...
      └── ...
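
This is the standard class-per-folder layout, so it loads directly with torchvision's ImageFolder. A minimal sketch (the transform is illustrative, not necessarily the one the training scripts use):

import torchvision.datasets as datasets
import torchvision.transforms as transforms

transform = transforms.Compose([
    transforms.RandomResizedCrop(224),  # illustrative augmentation only
    transforms.ToTensor(),
])
train_set = datasets.ImageFolder('imagenet/train', transform=transform)
print(len(train_set), 'images in', len(train_set.classes), 'classes')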

Pretrain

python -m torch.distributed.launch --nproc_per_node 32 main_pretrain.py \
    --batch_size 64 --epochs 800 \
    --model saim_base --query_depth 12 --prediction_head_type MLP \
    --gaussian_kernel_size 9 --gaussian_sigma 1 --norm_pix_loss \
    --blr 2e-4 --warmup_epochs 30 --weight_decay 0.5 \
    --data_path <imagenet-path> --output_dir <output-directory>
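
Note that --blr is a base learning rate. In MAE-style codebases the absolute rate is scaled by the effective batch size; assuming this repo follows the same lr = blr * effective_batch_size / 256 rule (check main_pretrain.py to confirm), the command above yields:

blr = 2e-4
eff_batch_size = 64 * 32         # --batch_size per GPU x --nproc_per_node = 2048
lr = blr * eff_batch_size / 256  # 1.6e-3 absolute learning rate
print(lr)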

Finetune

python -m torch.distributed.launch --nproc_per_node 32 main_finetune.py \
    --model vit_base_patch16 --cls_token --batch_size 32 \
    --blr 5e-4 --layer_decay 0.65 --epochs 100 --warmup_epochs 20 \
    --weight_decay 0.05 --drop_path 0.1 --reprob 0.25 --mixup 0.8 --cutmix 1.0 \
    --dist_eval --data_path <imagenet-path> \
    --finetune <pretrained-ckpt> --output_dir <output-directory>
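
Here --layer_decay 0.65 applies layer-wise learning-rate decay: blocks closer to the input get geometrically smaller learning rates than the head. A sketch of the usual BEiT/MAE-style rule (assumed here, not extracted from this repo):

layer_decay = 0.65
num_layers = 12  # transformer blocks in ViT-B
# Scale per parameter group: the patch embedding is "layer 0" and the
# head is "layer num_layers + 1", which trains at the full learning rate.
scales = [layer_decay ** (num_layers + 1 - i) for i in range(num_layers + 2)]
for i, s in enumerate(scales):
    print(f'layer {i:2d}: lr scale {s:.4f}')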

Linear Probing

python -m torch.distributed.launch --nproc_per_node 32 main_linprobe.py \
    --model vit_base_patch16 --cls_token --batch_size 64 \
    --blr 0.1 --epochs 90 --warmup_epochs 0 --weight_decay 0.0 \
    --dist_eval --data_path <imagenet-path> \
    --finetune <pretrained-ckpt> --output_dir <output-directory>
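
Linear probing freezes the pretrained encoder and trains only a linear classifier on top of it. A minimal sketch of the idea, using timm's ViT as a stand-in for the repo's model:

import timm
import torch.nn as nn

model = timm.create_model('vit_base_patch16_224', num_classes=1000)
for p in model.parameters():
    p.requires_grad = False  # freeze the pretrained backbone
# Replace the head with a fresh, trainable linear classifier.
model.head = nn.Linear(model.head.in_features, 1000)
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(trainable)  # only the head's weight and bias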

Visualization

(Figure: attention map visualization on ImageNet validation images)

We show example results on the ImageNet validation set. Description of images from left to right: (a) the original image, (b) the attention map of ViT-iGPT, (c) the attention map of SAIM. SAIM focuses on the main content of the image and obtains human-level attention representations from unlabeled data.
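
Attention maps like these can be read out of a ViT by capturing the post-softmax attention of the last block, e.g. the CLS token's attention over patch tokens averaged across heads. A sketch using a forward hook on a timm ViT (illustrative; the paper's figures may be produced differently):

import timm
import torch

model = timm.create_model('vit_base_patch16_224').eval()
attn_maps = []

def grab_attention(module, inputs, output):
    # The input to attn_drop is the post-softmax attention: (B, heads, N, N).
    attn_maps.append(inputs[0].detach())

model.blocks[-1].attn.attn_drop.register_forward_hook(grab_attention)

with torch.no_grad():
    model(torch.randn(1, 3, 224, 224))

attn = attn_maps[0].mean(dim=1)                  # average over heads: (1, N, N)
cls_to_patches = attn[0, 0, 1:].reshape(14, 14)  # CLS attention over the 14x14 grid
print(cls_to_patches.shape)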

Acknowledgement

The pretraining and finetuning of our project are based on DeiT, BEiT and MAE.

License

SAIM is released under the MIT License.

Citation

@inproceedings{qi2023exploring,
  title={Exploring Stochastic Autoregressive Image Modeling for Visual Representation},
  author={Qi, Yu and Yang, Fan and Zhu, Yousong and Liu, Yufei and Wu, Liwei and Zhao, Rui and Li, Wei},
  booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
  volume={37},
  number={2},
  pages={2074--2081},
  year={2023}
}
