Merge pull request #5 from zihangJiang/dev
Dev
zihangJiang committed Jun 3, 2021
2 parents 2e221d2 + 0ac7c92 commit f398513
Showing 25 changed files with 4,241 additions and 16 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -138,3 +138,4 @@ dmypy.json

# Pyre type checker
.pyre/
.DS_Store
Binary file removed Pipeline.png
Binary file not shown.
74 changes: 66 additions & 8 deletions README.md
@@ -2,17 +2,17 @@

This is a PyTorch implementation of our technical report.



![Compare](figures/Compare.png)

Comparison between the proposed LV-ViT and other recent transformer-based models. Note that we only show models with fewer than 100M parameters.

Our code is based on the [pytorch-image-models](https://github.com/rwightman/pytorch-image-models) by [Ross Wightman](https://github.com/rwightman).
### Update

**2021.6: Release training code and segmentation model.**

**2021.4: Release LV-ViT models.**

#### LV-ViT Models

@@ -63,7 +63,65 @@ We provide NFNet-F6 generated dense label map [here](https://drive.google.com/fi

#### Training

Train the LV-ViT-S:

If only 4 GPUs are available (`-b` sets the per-GPU batch size, so the total batch size stays at 1024 in both settings):

```
CUDA_VISIBLE_DEVICES=0,1,2,3 ./distributed_train.sh 4 /path/to/imagenet --model lvvit_s -b 256 --apex-amp --img-size 224 --drop-path 0.1 --token-label --token-label-data /path/to/label_data --token-label-size 14 --model-ema
```

If 8 GPUs are available:
```
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 ./distributed_train.sh 8 /path/to/imagenet --model lvvit_s -b 128 --apex-amp --img-size 224 --drop-path 0.1 --token-label --token-label-data /path/to/label_data --token-label-size 14 --model-ema
```


Train the LV-ViT-M and LV-ViT-L (run on 8 GPUs):


```
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 ./distributed_train.sh 8 /path/to/imagenet --model lvvit_m -b 128 --apex-amp --img-size 224 --drop-path 0.2 --token-label --token-label-data /path/to/label_data --token-label-size 14 --model-ema
```
```
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 ./distributed_train.sh 8 /path/to/imagenet --model lvvit_l -b 128 --lr 1.e-3 --aa rand-n3-m9-mstd0.5-inc1 --apex-amp --img-size 224 --drop-path 0.3 --token-label --token-label-data /path/to/label_data --token-label-size 14 --model-ema
```
If you want to train our LV-ViT on images with 384x384 resolution, please use `--img-size 384 --token-label-size 24`.

#### Fine-tuning

To fine-tune the pre-trained LV-ViT-S on images with 384x384 resolution:
```
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 ./distributed_train.sh 8 /path/to/imagenet --model lvvit_s -b 64 --apex-amp --img-size 384 --drop-path 0.1 --token-label --token-label-data /path/to/label_data --token-label-size 24 --lr 5.e-6 --min-lr 5.e-6 --weight-decay 1.e-8 --finetune /path/to/checkpoint
```

### Segmentation

Our segmentation models are fully based on the [MMSegmentation](https://github.com/open-mmlab/mmsegmentation) toolkit. The model and config files are under the `seg/` folder and follow the same directory structure. You can simply drop these files in to get started.

```shell
git clone https://github.com/open-mmlab/mmsegmentation # and install

cp seg/mmseg/models/backbones/vit.py mmsegmentation/mmseg/models/backbones/
cp -r seg/configs/lvvit mmsegmentation/configs/

# test upernet+lvvit_s (add --aug-test to test on multi scale)
cd mmsegmentation
./tools/dist_test.sh configs/lvvit/upernet_lvvit_s_512x512_160k_ade20k.py /path/to/checkpoint 8 --eval mIoU [--aug-test]
```

| Backbone | Method | Crop size | Lr Schd | mIoU | mIoU(ms) | Pixel Acc.| Param |Download |
| :------------------------------ | :------ | :-------- | :------ |:------- |:--------- | :-------- | :---- | :------ |
| LV-ViT-S | UperNet | 512x512 | 160k | 47.9 | 48.6 | 83.1 | 44M |[link](https://drive.google.com/file/d/1uqNgtSnIQ-AM8tHjte1DpCawzd_1B5zI/view?usp=sharing) |
| LV-ViT-M | UperNet | 512x512 | 160k | 49.4 | 50.6 | 83.5 | 77M |[link](https://drive.google.com/file/d/1-41KTtaam2tysS-0y8Ggr8DGN5tWTyUR/view?usp=sharing) |
| LV-ViT-L | UperNet | 512x512 | 160k | 50.9 | 51.8 | 84.1 | 209M |[link](https://drive.google.com/file/d/16WWdlgSjtVqYLufT83BLhhl1NAmKENqd/view?usp=sharing) |


### Visualization

We apply the visualization method from this [repo](https://github.com/hila-chefer/Transformer-Explainability) to visualize the parts of the image that led to a certain classification for DeiT-Base and our LV-ViT-S. The image regions the network relies on for its decision are highlighted in red.

![Compare](figures/Top1.jpg)
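
For intuition, below is a minimal attention-rollout sketch in PyTorch. It is a simplified stand-in for illustration only, not the LRP-based method used in that repo, and the `attn_weights` input format is an assumption:

```python
import torch

def attention_rollout(attn_weights):
    # attn_weights: list of per-layer attention tensors, each of shape
    # (num_heads, num_tokens, num_tokens), collected from a ViT forward
    # pass -- an assumed input format, not an API of either repo.
    num_tokens = attn_weights[0].shape[-1]
    rollout = torch.eye(num_tokens)
    for attn in attn_weights:
        attn = attn.mean(dim=0)                       # average over heads
        attn = attn + torch.eye(num_tokens)           # account for residual connections
        attn = attn / attn.sum(dim=-1, keepdim=True)  # re-normalize rows
        rollout = attn @ rollout                      # compose across layers
    # relevance of each patch token with respect to the class token
    return rollout[0, 1:]

# example: 12 layers, 6 heads, 197 tokens (CLS + 14x14 patches)
attn = [torch.rand(6, 197, 197).softmax(dim=-1) for _ in range(12)]
relevance = attention_rollout(attn)  # (196,), reshapeable to a 14x14 heat map
```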

#### Reference
If you use this repo or find it useful, please consider citing:
@@ -77,4 +135,4 @@
```

#### Related projects
[T2T-ViT](https://github.com/yitu-opensource/T2T-ViT/), [Re-labeling ImageNet](https://github.com/naver-ai/relabel_imagenet), [MMSegmentation](https://github.com/open-mmlab/mmsegmentation), [Transformer Explainability](https://github.com/hila-chefer/Transformer-Explainability).
5 changes: 5 additions & 0 deletions data/__init__.py
@@ -0,0 +1,5 @@
from .dataset import DatasetTokenLabel, create_token_label_dataset
from .loader import create_token_label_loader
from .label_transforms_factory import create_token_label_transform
from .mixup import TokenLabelMixup, FastCollateTokenLabelMixup, mixup_target as create_token_label_target

128 changes: 128 additions & 0 deletions data/dataset.py
@@ -0,0 +1,128 @@
""" Image dataset with label maps
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import torch.utils.data as data

import os
import re
import torch
import tarfile
import logging
from PIL import Image
_logger = logging.getLogger('token_label_dataset')

IMG_EXTENSIONS = ['.png', '.jpg', '.jpeg']


def natural_key(string_):
"""See http://www.codinghorror.com/blog/archives/001018.html"""
return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]


def find_images_and_targets(folder, types=IMG_EXTENSIONS, class_to_idx=None, leaf_name_only=True, sort=True):
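    """Collect (filepath, class_idx) pairs from an ImageFolder-style directory tree."""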
labels = []
filenames = []
for root, subdirs, files in os.walk(folder, topdown=False):
rel_path = os.path.relpath(root, folder) if (root != folder) else ''
label = os.path.basename(rel_path) if leaf_name_only else rel_path.replace(os.path.sep, '_')
for f in files:
base, ext = os.path.splitext(f)
if ext.lower() in types:
filenames.append(os.path.join(root, f))
labels.append(label)
if class_to_idx is None:
# building class index
unique_labels = set(labels)
sorted_labels = list(sorted(unique_labels, key=natural_key))
class_to_idx = {c: idx for idx, c in enumerate(sorted_labels)}
images_and_targets = [(f, class_to_idx[l]) for f, l in zip(filenames, labels) if l in class_to_idx]
if sort:
images_and_targets = sorted(images_and_targets, key=lambda k: natural_key(k[0]))
return images_and_targets, class_to_idx


def load_class_map(filename, root=''):
class_map_path = filename
if not os.path.exists(class_map_path):
class_map_path = os.path.join(root, filename)
assert os.path.exists(class_map_path), 'Cannot locate specified class map file (%s)' % filename
class_map_ext = os.path.splitext(filename)[-1].lower()
if class_map_ext == '.txt':
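        # one class name per line; the line number (starting at 0) becomes the class index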
with open(class_map_path) as f:
class_to_idx = {v.strip(): k for k, v in enumerate(f)}
else:
assert False, 'Unsupported class map extension'
return class_to_idx


class DatasetTokenLabel(data.Dataset):
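    """ImageFolder-style dataset that pairs each image with a dense token-label map stored as a .pt file."""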

def __init__(
self,
root,
label_root,
load_bytes=False,
transform=None,
class_map=''):

class_to_idx = None
if class_map:
class_to_idx = load_class_map(class_map, root)
images, class_to_idx = find_images_and_targets(root, class_to_idx=class_to_idx)
if len(images) == 0:
raise RuntimeError(f'Found 0 images in subfolders of {root}. '
f'Supported image extensions are {", ".join(IMG_EXTENSIONS)}')
self.root = root
self.label_root = label_root
self.samples = images
self.imgs = self.samples # torchvision ImageFolder compat
self.class_to_idx = class_to_idx
self.load_bytes = load_bytes
self.transform = transform

def __getitem__(self, index):
path, target = self.samples[index]
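        # the label map lives at <label_root>/<class_folder>/<image_stem>.pt,
        # mirroring the image's location under the dataset root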
score_path = os.path.join(
self.label_root,
'/'.join(path.split('/')[-2:]).split('.')[0] + '.pt')

img = open(path, 'rb').read() if self.load_bytes else Image.open(path).convert('RGB')
score_maps = torch.load(score_path).float()
if self.transform is not None:
img, score_maps = self.transform(img, score_maps)
# append ground truth after coords
        score_maps[-1, 0, 0, 5] = target
return img, score_maps

def __len__(self):
return len(self.samples)

def filename(self, index, basename=False, absolute=False):
filename = self.samples[index][0]
if basename:
filename = os.path.basename(filename)
elif not absolute:
filename = os.path.relpath(filename, self.root)
return filename

def filenames(self, basename=False, absolute=False):
fn = lambda x: x
if basename:
fn = os.path.basename
elif not absolute:
fn = lambda x: os.path.relpath(x, self.root)
return [fn(x[0]) for x in self.samples]


def create_token_label_dataset(dataset_type, root, label_root):
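    # note: dataset_type is currently unused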
train_dir = os.path.join(root, 'train')
if not os.path.exists(train_dir):
_logger.error('Training folder does not exist at: {}'.format(train_dir))
exit(1)
if not os.path.exists(label_root):
_logger.error('Label folder does not exist at: {}'.format(label_root))
exit(1)
return DatasetTokenLabel(train_dir, label_root)
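
As a quick sanity check, the dataset can be exercised on its own. A minimal sketch, assuming an ImageNet-style folder with a `train/` subfolder and a matching label-map folder (both paths are placeholders):

```python
from data.dataset import create_token_label_dataset

# placeholder paths: point at the real ImageNet root and label-map folder
dataset = create_token_label_dataset(
    'token_label', '/path/to/imagenet', '/path/to/label_data')

# with no transform set, this returns a PIL image and the raw label-map tensor
img, score_maps = dataset[0]
print(len(dataset), score_maps.shape)
```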