
Commit

update fine-tuning instructions
notjzh committed Jun 16, 2021
1 parent 73cf498 commit aa438ef
Showing 20 changed files with 20 additions and 15 deletions.
7 changes: 6 additions & 1 deletion README.md
@@ -10,7 +10,7 @@ Our codes are based on the [pytorch-image-models](https://github.com/rwightman/p

### Update

-**2021.6: Support `pip install tlmm` to use our Token Labeling for image models.**
+**2021.6: Support `pip install tlt` to use our Token Labeling Toolbox for image models.**

**2021.6: Release training code and segmentation model.**
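
(For reference, a minimal usage sketch of the renamed package. The submodule names below are the ones imported in `main.py` further down in this commit; everything else is illustrative.)

```
# After `pip install tlt`, the toolbox is imported under its new name.
import tlt.models   # LV-ViT model definitions
import tlt.data     # token-label datasets, loaders, and mixup
import tlt.loss     # token-label loss functions
import tlt.utils    # checkpoint helpers such as load_pretrained_weights

# List the LV-ViT variants exposed by the package (lvvit_s/m/l appear below).
print([name for name in dir(tlt.models) if name.startswith("lvvit")])
```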

@@ -99,6 +99,11 @@ To fine-tune the pre-trained LV-ViT-S on images with 384x384 resolution:
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 ./distributed_train.sh 8 /path/to/imagenet --model lvvit_s -b 64 --apex-amp --img-size 384 --drop-path 0.1 --token-label --token-label-data /path/to/label_data --token-label-size 24 --lr 5.e-6 --min-lr 5.e-6 --weight-decay 1.e-8 --finetune /path/to/checkpoint
```

+To fine-tune the pre-trained LV-ViT-S on other datasets without token labeling:
+```
+CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 ./distributed_train.sh 8 /path/to/dataset --model lvvit_s -b 64 --apex-amp --img-size 224 --drop-path 0.1 --token-label --token-label-size 14 --dense-weight 0.0 --num-classes $NUM_CLASSES --finetune /path/to/checkpoint
+```

### Segmentation

Our segmentation model is fully based upon the [MMSegmentation](https://github.com/open-mmlab/mmsegmentation) Toolkit. The model and config files are under the `seg/` folder and follow the same folder structure. You can simply drop these files in to get started.
8 changes: 4 additions & 4 deletions flops_computation.py
@@ -1,9 +1,9 @@
-import tlmm.models
+import tlt.models
# summary of model flops and parameters

-model_list = [tlmm.models.lvvit_s,
-tlmm.models.lvvit_m,
-tlmm.models.lvvit_l]
+model_list = [tlt.models.lvvit_s,
+tlt.models.lvvit_m,
+tlt.models.lvvit_l]

img_size_list=[224,288,384,448]
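
Only the renamed imports and the two lists of `flops_computation.py` appear in this diff. As a rough, hypothetical stand-in for how the model list might be consumed (the file's actual FLOPs logic is unchanged and not shown here):

```
# Hypothetical sketch, not the repository's FLOPs routine: report parameter
# counts for each LV-ViT variant, assuming the lvvit_* entries are
# no-argument factories returning nn.Module instances.
import tlt.models

model_list = [tlt.models.lvvit_s,
              tlt.models.lvvit_m,
              tlt.models.lvvit_l]

for create_model in model_list:
    model = create_model()
    n_params = sum(p.numel() for p in model.parameters())
    print(f"{create_model.__name__}: {n_params / 1e6:.1f}M parameters")
```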

10 changes: 5 additions & 5 deletions main.py
@@ -24,10 +24,10 @@
from timm.scheduler import create_scheduler
from timm.utils import ApexScaler, NativeScaler

-import tlmm.models
-from tlmm.data import create_token_label_target, TokenLabelMixup, FastCollateTokenLabelMixup, create_token_label_loader, create_token_label_dataset
-from tlmm.loss import TokenLabelCrossEntropy, TokenLabelSoftTargetCrossEntropy
-from tlmm.utils import load_for_transfer_learning
+import tlt.models
+from tlt.data import create_token_label_target, TokenLabelMixup, FastCollateTokenLabelMixup, create_token_label_loader, create_token_label_dataset
+from tlt.loss import TokenLabelCrossEntropy, TokenLabelSoftTargetCrossEntropy
+from tlt.utils import load_pretrained_weights


try:
@@ -359,7 +359,7 @@ def main():
args.num_classes = model.num_classes # FIXME handle model default vs config num_classes more elegantly

if args.finetune:
-load_for_transfer_learning(model=model,checkpoint_path=args.finetune,use_ema=args.model_ema, strict=False, num_classes=args.num_classes)
+load_pretrained_weights(model=model,checkpoint_path=args.finetune,use_ema=args.model_ema, strict=False, num_classes=args.num_classes)

if args.local_rank == 0:
_logger.info('Model %s created, param count: %d' %
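
The renamed helper can also be called directly for transfer learning. Here is a minimal sketch based on the signature shown in `tlt/utils/utils.py` below; the checkpoint path and class count are placeholders:

```
import tlt.models
from tlt.utils import load_pretrained_weights

# Hypothetical fine-tuning setup: adapt a pretrained LV-ViT-S checkpoint to a
# 100-class dataset. strict=False because the classifier head changes shape.
model = tlt.models.lvvit_s(num_classes=100)  # assumes the factory accepts num_classes
load_pretrained_weights(
    model=model,
    checkpoint_path="/path/to/checkpoint",
    use_ema=False,
    strict=False,
    num_classes=100,
)
```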
4 changes: 2 additions & 2 deletions setup.py
@@ -1,11 +1,11 @@
from setuptools import setup, find_packages

setup(
-name = 'tlmm',
+name = 'tlt',
packages = find_packages(exclude=['seg','visualize']),
version = '0.1.0',
license='Apache License 2.0',
-description = 'Token labeling for training image models',
+description = 'Token Labeling Toolbox for training image models',
author = 'Zihang Jiang',
author_email = 'jzh0103@gmail.com',
url = 'https://github.com/zihangJiang/TokenLabeling',
1 change: 0 additions & 1 deletion tlmm/utils/__init__.py

This file was deleted.

12 files renamed without changes.
1 change: 1 addition & 0 deletions tlt/utils/__init__.py
@@ -0,0 +1 @@
+from .utils import load_pretrained_weights
2 changes: 1 addition & 1 deletion tlmm/utils/utils.py → tlt/utils/utils.py
@@ -106,7 +106,7 @@ def load_state_dict(checkpoint_path,model, use_ema=False, num_classes=1000):
raise FileNotFoundError()


-def load_for_transfer_learning(model, checkpoint_path, use_ema=False, strict=True, num_classes=1000):
+def load_pretrained_weights(model, checkpoint_path, use_ema=False, strict=True, num_classes=1000):
state_dict = load_state_dict(checkpoint_path, model, use_ema, num_classes)
model.load_state_dict(state_dict, strict=strict)

2 changes: 1 addition & 1 deletion validate.py
@@ -20,7 +20,7 @@
from timm.models.helpers import load_state_dict
from timm.data import create_dataset, create_loader, resolve_data_config, RealLabelsImagenet
from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging, set_jit_legacy
-import tlmm.models
+import tlt.models

has_apex = False
try:
