Skip to content

Commit

Permalink
readme
Browse files Browse the repository at this point in the history
  • Loading branch information
陈科研 committed Nov 26, 2023
1 parent 5642e49 commit bebe148
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 17 deletions.
34 changes: 30 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,14 @@ If you find this project helpful, please give us a star ⭐️, your support is

🌟 **2023.11.25** Updated the code of RSPrompter, which is completely consistent with the API interface and usage method of MMDetection.

🌟 **2023.11.26** Added the LoRA efficient fine-tuning method, and made the input image size variable, reducing the memory usage of the model.

## TODO

- [X] Consistent API interface and usage method with MMDetection
- [ ] Reduce the memory usage of the model while ensuring performance by reducing the image input and combining with the large model fine-tuning technology
- [ ] Dynamically variable image size input
- [ ] Efficient fine-tuning method in the model
- [X] Reduce the memory usage of the model while ensuring performance by reducing the image input and combining with the large model fine-tuning technology
- [X] Dynamically variable image size input
- [X] Efficient fine-tuning method in the model
- [ ] Add SAM-cls model

## Table of Contents
Expand Down Expand Up @@ -137,7 +139,7 @@ mim install "mmcv>=2.0.0"
**Step 4**: Install other dependencies.

```shell
pip install -U transformers wandb einops pycocotools shapely scipy terminaltables importlib
pip install -U transformers wandb einops pycocotools shapely scipy terminaltables importlib peft mat4py
```

**Step 5**: [Optional] Install DeepSpeed.
Expand Down Expand Up @@ -330,6 +332,30 @@ python zero_to_fp32.py . $SAVE_CHECKPOINT_NAME -t $CHECKPOINT_DIR # $SAVE_CHECK
- Change `runner_type` in the Config file to `Runner`.
- Use the method of MMDetection to evaluate, and you can get the evaluation results.


### 3. About resource consumption

Here we list the resource consumption of using different models for your reference.

| Model Name | Backbone | Image Size | GPU | Batch Size | Acceleration Strategy | Single Card Memory Usage |
|:---------------------:|:--------:| :------: |:------------------:|:----------:|:----------:|:-------:|
| SAM-seg (Mask R-CNN) | ViT-B/16 | 1024x1024 | 1x RTX 4090 24G | 8 | AMP FP16 | 19.4 GB |
| SAM-seg (Mask2Former) | ViT-B/16 | 1024x1024 | 1x RTX 4090 24G | 8 | AMP FP16 | 21.5 GB |
| SAM-det | ResNet50 | 1024x1024 | 1x RTX 4090 24G | 8 | FP32 | 16.6 GB |
| RSPrompter-anchor | ViT-B/16 | 1024x1024 | 1x RTX 4090 24G | 2 | AMP FP16 | 20.9 GB |
| RSPrompter-query | ViT-B/16 | 1024x1024 | 1x RTX 4090 24G | 1 | AMP FP16 | OOM |
| RSPrompter-query | ViT-B/16 | 1024x1024 | 8x NVIDIA A100 40G | 1 | ZeRO-2 | 39.6 GB |
| RSPrompter-anchor | ViT-B/16 | 512x512 | 8x RTX 4090 24G | 4 | AMP FP16 | 20.9 GB |
| RSPrompter-query | ViT-B/16 | 512x512 | 8x RTX 4090 24G | 2 | ZeRO-2 | 21.1 GB |


Note: Low-resolution input images can effectively reduce the memory usage of the model, but their actual performance has not been verified. For details, please refer to [Config file](configs/rsprompter/rsprompter_query-nwpu-peft-512.py).

### 4. Solution to dist_train.sh: Bad substitution

If you encounter the error `Bad substitution` when running `dist_train.sh`, please use `bash dist_train.sh` to run the script.


</details>

## Acknowledgement
Expand Down
14 changes: 9 additions & 5 deletions README_zh-CN.md
Original file line number Diff line number Diff line change
Expand Up @@ -64,13 +64,15 @@

🌟 **2023.11.25** 更新了RSPrompter的代码,完全与MMDetection保持一致的API接口及使用方法。

🌟 **2023.11.26** 加入了LoRA高效微调方法,并使得输入图像尺寸可变,减小了模型的显存占用。


## TODO

- [X] 与MMDetection保持一致的API接口及使用方法
- [ ] 通过减小图像输入并结合大模型微调技术在保证性能的同时减小模型的显存占用
- [ ] 动态可变的图像尺寸输入
- [ ] 在模型中加入高效微调的方法
- [X] 通过减小图像输入并结合大模型微调技术在保证性能的同时减小模型的显存占用
- [X] 动态可变的图像尺寸输入
- [X] 在模型中加入高效微调的方法
- [ ] 加入SAM-cls模型

## 目录
Expand Down Expand Up @@ -138,7 +140,7 @@ mim install "mmcv>=2.0.0"
**步骤 4**:安装其他依赖项。

```shell
pip install -U transformers wandb einops pycocotools shapely scipy terminaltables importlib
pip install -U transformers wandb einops pycocotools shapely scipy terminaltables importlib peft mat4py
```

**步骤 5**[可选] 安装 DeepSpeed。
Expand Down Expand Up @@ -341,8 +343,10 @@ python zero_to_fp32.py . $SAVE_CHECKPOINT_NAME -t $CHECKPOINT_DIR # $SAVE_CHECK
| RSPrompter-anchor | ViT-B/16 | 1024x1024 | 1x RTX 4090 24G | 2 | AMP FP16 | 20.9 GB |
| RSPrompter-query | ViT-B/16 | 1024x1024 | 1x RTX 4090 24G | 1 | AMP FP16 | OOM |
| RSPrompter-query | ViT-B/16 | 1024x1024 | 8x NVIDIA A100 40G | 1 | ZeRO-2 | 39.6 GB |
| RSPrompter-anchor | ViT-B/16 | 512x512 | 8x RTX 4090 24G | 4 | AMP FP16 | 20.9 GB |
| RSPrompter-query | ViT-B/16 | 512x512 | 8x RTX 4090 24G | 2 | ZeRO-2 | 21.1 GB |


注解:低分辨率输入图像可以有效减小模型的显存占用,但是其实际性能并未验证,具体见[配置文件](configs/rsprompter/rsprompter_query-nwpu-peft-512.py)


### 4. dist_train.sh: Bad substitution的解决
Expand Down
8 changes: 4 additions & 4 deletions configs/rsprompter/rsprompter_anchor-nwpu-peft-512.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
)

vis_backends = [dict(type='LocalVisBackend'),
# dict(type='WandbVisBackend', init_kwargs=dict(project='rsprompter-nwpu', group='rsprompter-query', name="rsprompter_anchor-nwpu-peft-512"))
dict(type='WandbVisBackend', init_kwargs=dict(project='rsprompter-nwpu', group='rsprompter-query', name="rsprompter_anchor-nwpu-peft-512"))
]
visualizer = dict(
type='DetLocalVisualizer', vis_backends=vis_backends, name='visualizer')
Expand Down Expand Up @@ -136,10 +136,10 @@

dataset_type = 'NWPUInsSegDataset'
#### should be changed align with your code root and data root
code_root = '/mnt/search01/usr/chenkeyan/codes/mm_rsprompter'
data_root = '/mnt/search01/dataset/cky_data/NWPU10'
code_root = '/mnt/home/cky/RSPrompter'
data_root = '/mnt/home/cky/data/NWPU'

batch_size_per_gpu = 8
batch_size_per_gpu = 4
num_workers = 8
persistent_workers = True
train_dataloader = dict(
Expand Down
8 changes: 4 additions & 4 deletions configs/rsprompter/rsprompter_query-nwpu-peft-512.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
)

vis_backends = [dict(type='LocalVisBackend'),
# dict(type='WandbVisBackend', init_kwargs=dict(project='rsprompter-nwpu', group='rsprompter-query', name="rsprompter_query-nwpu-peft-512"))
dict(type='WandbVisBackend', init_kwargs=dict(project='rsprompter-nwpu', group='rsprompter-query', name="rsprompter_query-nwpu-peft-512"))
]
visualizer = dict(
type='DetLocalVisualizer', vis_backends=vis_backends, name='visualizer')
Expand Down Expand Up @@ -142,10 +142,10 @@

dataset_type = 'NWPUInsSegDataset'
#### should be changed align with your code root and data root
code_root = '/mnt/search01/usr/chenkeyan/codes/mm_rsprompter'
data_root = '/mnt/search01/dataset/cky_data/NWPU10'
code_root = '/mnt/home/cky/RSPrompter'
data_root = '/mnt/home/cky/data/NWPU'

batch_size_per_gpu = 4
batch_size_per_gpu = 2
num_workers = 8
persistent_workers = True
train_dataloader = dict(
Expand Down

0 comments on commit bebe148

Please sign in to comment.