Skip to content

Commit

Permalink
add
Browse files Browse the repository at this point in the history
  • Loading branch information
陈科研 committed Dec 23, 2023
1 parent f72cc8b commit f84519b
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 94 deletions.
138 changes: 48 additions & 90 deletions README_zh-CN.md
Original file line number Diff line number Diff line change
Expand Up @@ -137,73 +137,61 @@ pip install -U wandb einops importlib peft scipy ftfy prettytable torchmetrics

</details>

### 安装 RSPrompter
### 安装 TTP

下载或克隆 RSPrompter 仓库即可。
下载或克隆 TTP 仓库即可。

```shell
git clone git@github.com:KyanChen/RSPrompter.git
cd RSPrompter
git clone git@github.com:KyanChen/TTP.git
cd TTP
```

## 数据集准备

<details>

### 基础实例分割数据集
### Levir-CD变化检测数据集

我们提供论文中使用的实例分割数据集准备方法。
#### 数据下载

#### WHU 建筑物数据集
- 图片及标签下载地址: [Levir-CD](https://chenhao.in/LEVIR/)

- 图片下载地址: [WHU建筑物数据集](https://aistudio.baidu.com/datasetdetail/56502)

- 语义标签转实例标签:我们提供了相应的[转换脚本](tools/rsprompter/whu2coco.py)来将 WHU 建筑物数据集的语义标签转换为实例标签。

#### NWPU VHR-10 数据集

- 图片下载地址: [NWPU VHR-10数据集](https://aistudio.baidu.com/datasetdetail/52812)

- 实例标签下载地址: [NWPU VHR-10实例标签](https://github.com/chaozhong2010/VHR-10_dataset_coco)

#### SSDD 数据集

- 图片下载地址: [SSDD数据集](https://aistudio.baidu.com/datasetdetail/56503)

- 实例标签下载地址: [SSDD实例标签](https://github.com/chaozhong2010/VHR-10_dataset_coco)

**注解**:在本项目的 `data` 文件夹中,我们提供了上述数据集的实例标签,你可以直接使用。

#### 组织方式

你也可以选择其他来源进行数据的下载,但是需要将数据集组织成如下的格式:

```
${DATASET_ROOT} # 数据集根目录,例如:/home/username/data/NWPU
├── annotations
│ ├── train.json
│ ├── val.json
│ └── test.json
└── images
├── train
├── val
└── test
${DATASET_ROOT} # 数据集根目录,例如:/home/username/data/levir-cd
├── train
│ ├── A
│ ├── B
│ └── label
├── val
│ ├── A
│ ├── B
│ └── label
└── test
├── A
├── B
└── label
```

注解:在项目文件夹中,我们提供了一个名为 `data` 的文件夹,其中包含了上述数据集的组织方式的示例。

### 其他数据集

如果你想使用其他数据集,可以参考 [MMDetection 文档](https://mmdetection.readthedocs.io/zh-cn/latest/user_guides/dataset_prepare.html) 来进行数据集的准备。
如果你想使用其他数据集,可以参考 [MMSegmentation 文档](https://mmsegmentation.readthedocs.io/zh-cn/latest/user_guides/2_dataset_prepare.html) 来进行数据集的准备。

</details>

## 模型训练

### SAM-based 模型
### TTP 模型

#### Config 文件及主要参数解析

我们提供了论文中使用的 SAM-based 模型的配置文件,你可以在 `configs/rsprompter` 文件夹中找到它们。Config 文件完全与 MMDetection 保持一致的 API 接口及使用方法。下面我们提供了一些主要参数的解析。如果你想了解更多参数的含义,可以参考 [MMDetection 文档](https://mmdetection.readthedocs.io/zh-cn/latest/user_guides/config.html)
我们提供了论文中使用的 TTP 模型的配置文件,你可以在 `configs/TTP` 文件夹中找到它们。Config 文件完全与 MMSegmentation 保持一致的 API 接口及使用方法。下面我们提供了一些主要参数的解析。如果你想了解更多参数的含义,可以参考 [MMSegmentation 文档](https://mmsegmentation.readthedocs.io/zh-cn/latest/user_guides/1_config.html)

<details>

Expand All @@ -213,43 +201,35 @@ ${DATASET_ROOT} # 数据集根目录,例如:/home/username/data/NWPU
- `default_hooks-CheckpointHook`:模型训练过程中的检查点保存配置,一般不需要修改。
- `default_hooks-visualization`:模型训练过程中的可视化配置,**训练时注释,测试时取消注释**
- `vis_backends-WandbVisBackend`:网络端可视化工具的配置,**打开注释后,需要在 `wandb` 官网上注册账号,可以在网络浏览器中查看训练过程中的可视化结果**
- `num_classes`:数据集的类别数,**需要根据数据集的类别数进行修改**
- `prompt_shape`:Prompt 的形状,第一个参数代表 $N_p$,第二个参数代表 $K_p$,一般不需要修改。
- `hf_sam_pretrain_name`:HuggingFace Spaces 上的 SAM 模型的名称,一般不需要修改。
- `hf_sam_pretrain_ckpt_path`:HuggingFace Spaces 上的 SAM 模型的检查点路径,**需要修改为你自己的路径**,可以使用[下载脚本](tools/rsprompter/download_hf_sam_pretrain_ckpt.py)来下载。
- `model-decoder_freeze`:是否冻结SAM解码器的参数,一般不需要修改。
- `model-neck-feature_aggregator-hidden_channels`:特征聚合器的隐藏通道数,一般不需要修改。
- `model-neck-feature_aggregator-select_layers`:特征聚合器的选择层数,**需要根据选择的SAM骨干类型进行修改**
- `model-mask_head-with_sincos`:是否在预测提示时使用 sin 正则化,一般不需要修改。
- `sam_pretrain_ckpt_path`:MMPretrain 提供的 SAM 主干的检查点路径,参考[下载地址](https://github.com/open-mmlab/mmpretrain/tree/main/configs/sam)
- `model-backbone-peft_cfg`:是否引入微调参数,一般不需要修改。
- `dataset_type`:数据集的类型,**需要根据数据集的类型进行修改**
- `code_root`:代码根目录,**修改为本项目根目录的绝对路径**
- `data_root`:数据集根目录,**修改为数据集根目录的绝对路径**
- `batch_size_per_gpu`:单卡的 batch size,**需要根据显存大小进行修改**
- `resume`: 是否断点续训,一般不需要修改。
- `load_from`:模型的预训练的检查点路径,一般不需要修改。
- `max_epochs`:最大训练轮数,一般不需要修改。
- `runner_type`:训练器的类型,需要和`optim_wrapper``strategy`的类型保持一致,一般不需要修改。

</details>


#### 单卡训练

```shell
python tools/train.py configs/rsprompter/xxx.py # xxx.py 为你想要使用的配置文件
python tools/train.py configs/TTP/xxx.py # xxx.py 为你想要使用的配置文件
```

#### 多卡训练

```shell
sh ./tools/dist_train.sh configs/rsprompter/xxx.py ${GPU_NUM} # xxx.py 为你想要使用的配置文件,GPU_NUM 为使用的 GPU 数量
sh ./tools/dist_train.sh configs/TTP/xxx.py ${GPU_NUM} # xxx.py 为你想要使用的配置文件,GPU_NUM 为使用的 GPU 数量
```

### 其他实例分割模型

<details>

如果你想使用其他实例分割模型,可以参考 [MMDetection](https://github.com/open-mmlab/mmdetection/tree/main) 来进行模型的训练,也可以将其Config文件放入本项目的 `configs` 文件夹中,然后按照上述的方法进行训练。
如果你想使用其他变化检测模型,可以参考 [Open-CD](https://github.com/likyoo/open-cd) 来进行模型的训练,也可以将其Config文件放入本项目的 `configs` 文件夹中,然后按照上述的方法进行训练。

</details>

Expand All @@ -258,13 +238,13 @@ sh ./tools/dist_train.sh configs/rsprompter/xxx.py ${GPU_NUM} # xxx.py 为你
#### 单卡测试:

```shell
python tools/test.py configs/rsprompter/xxx.py ${CHECKPOINT_FILE} # xxx.py 为你想要使用的配置文件,CHECKPOINT_FILE 为你想要使用的检查点文件
python tools/test.py configs/TTP/xxx.py ${CHECKPOINT_FILE} # xxx.py 为你想要使用的配置文件,CHECKPOINT_FILE 为你想要使用的检查点文件
```

#### 多卡测试:

```shell
sh ./tools/dist_test.sh configs/rsprompter/xxx.py ${CHECKPOINT_FILE} ${GPU_NUM} # xxx.py 为你想要使用的配置文件,CHECKPOINT_FILE 为你想要使用的检查点文件,GPU_NUM 为使用的 GPU 数量
sh ./tools/dist_test.sh configs/TTP/xxx.py ${CHECKPOINT_FILE} ${GPU_NUM} # xxx.py 为你想要使用的配置文件,CHECKPOINT_FILE 为你想要使用的检查点文件,GPU_NUM 为使用的 GPU 数量
```

**注解**:如果需要获取可视化结果,可以在 Config 文件中取消 `default_hooks-visualization` 的注释。
Expand All @@ -275,13 +255,13 @@ sh ./tools/dist_test.sh configs/rsprompter/xxx.py ${CHECKPOINT_FILE} ${GPU_NUM}
#### 单张图像预测:

```shell
python demo/image_demo.py ${IMAGE_FILE} configs/rsprompter/xxx.py --weights ${CHECKPOINT_FILE} --out-dir ${OUTPUT_DIR} # IMAGE_FILE 为你想要预测的图像文件,xxx.py 为你想要使用的配置文件,CHECKPOINT_FILE 为你想要使用的检查点文件,OUTPUT_DIR 为预测结果的输出路径
python demo/image_demo.py ${IMAGE_FILE} configs/TTP/xxx.py --weights ${CHECKPOINT_FILE} --out-dir ${OUTPUT_DIR} # IMAGE_FILE 为你想要预测的图像文件,xxx.py 为你想要使用的配置文件,CHECKPOINT_FILE 为你想要使用的检查点文件,OUTPUT_DIR 为预测结果的输出路径
```

#### 多张图像预测:

```shell
python demo/image_demo.py ${IMAGE_DIR} configs/rsprompter/xxx.py --weights ${CHECKPOINT_FILE} --out-dir ${OUTPUT_DIR} # IMAGE_DIR 为你想要预测的图像文件夹,xxx.py 为你想要使用的配置文件,CHECKPOINT_FILE 为你想要使用的检查点文件,OUTPUT_DIR 为预测结果的输出路径
python demo/image_demo.py ${IMAGE_DIR} configs/TTP/xxx.py --weights ${CHECKPOINT_FILE} --out-dir ${OUTPUT_DIR} # IMAGE_DIR 为你想要预测的图像文件夹,xxx.py 为你想要使用的配置文件,CHECKPOINT_FILE 为你想要使用的检查点文件,OUTPUT_DIR 为预测结果的输出路径
```


Expand All @@ -290,44 +270,26 @@ python demo/image_demo.py ${IMAGE_DIR} configs/rsprompter/xxx.py --weights ${CH

<details>

我们在这里列出了使用时的一些常见问题及其相应的解决方案。如果您发现有一些问题被遗漏,请随时提 PR 丰富这个列表。如果您无法在此获得帮助,请使用[issue](https://github.com/KyanChen/RSPrompter/issues)来寻求帮助。请在模板中填写所有必填信息,这有助于我们更快定位问题。
我们在这里列出了使用时的一些常见问题及其相应的解决方案。如果您发现有一些问题被遗漏,请随时提 PR 丰富这个列表。如果您无法在此获得帮助,请使用[issue](https://github.com/KyanChen/TTP/issues)来寻求帮助。请在模板中填写所有必填信息,这有助于我们更快定位问题。

### 1. 是否需要安装MMDetection
### 1. 是否需要安装MMSegmentation,MMPretrain,MMDet,Open-CD

我们建议您不要安装MMDetection,因为我们已经对MMDetection的代码进行了部分修改,如果您安装了MMDetection,可能会导致代码运行出错。如果你出现了模块尚未被注册的错误,请检查:
我们建议您不要安装它们,因为我们已经对它们的代码进行了部分修改,如果您安装了它们,可能会导致代码运行出错。如果你出现了模块尚未被注册的错误,请检查:

- 是否安装了MMDetection,若有则卸载
- 是否安装了这些库,若有则卸载
- 是否在类名前加上了`@MODELS.register_module()`,若没有则加上
- 是否在`__init__.py`中加入了`from .xxx import xxx`,若没有则加上
- 是否在Config文件中加入了`custom_imports = dict(imports=['mmdet.rsprompter'], allow_failed_imports=False)`,若没有则加上

### 2. 使用DeepSpeed训练时后如何评测模型?

我们建议您使用DeepSpeed训练模型,因为DeepSpeed可以大幅度提升模型的训练速度。但是,DeepSpeed的训练方式与MMDetection的训练方式不同,因此在使用DeepSpeed训练模型后,需要使用MMDetection的方式进行评测。具体来说,您需要:

- 将DeepSpeed训练的模型转换为MMDetection的模型,进入到存储模型的文件夹,运行
```shell
python zero_to_fp32.py . $SAVE_CHECKPOINT_NAME -t $CHECKPOINT_DIR # $SAVE_CHECKPOINT_NAME为转换后的模型名称,$CHECKPOINT_DIR为DeepSpeed训练的模型名称
```
- 将Config文件中的`runner_type`改为`Runner`
- 使用MMDetection的方式进行评测,即可得到评测结果。
- 是否在Config文件中加入了`custom_imports = dict(imports=['mmseg.ttp'], allow_failed_imports=False)`,若没有则加上

### 3. 关于资源消耗情况

这里我们列出了使用不同模型的资源消耗情况,供您参考。
### 2. 关于资源消耗情况

| 模型名称 | 骨干网络类型 | 图像尺寸 | GPU | Batch Size | 加速策略 | 单卡显存占用 |
|:---------------------:|:--------:| :------: |:------------------:|:----------:|:----------:|:-------:|
| SAM-seg (Mask R-CNN) | ViT-B/16 | 1024x1024 | 1x RTX 4090 24G | 8 | AMP FP16 | 14 GB | 3H
| SAM-seg (Mask2Former) | ViT-B/16 | 1024x1024 | 1x RTX 4090 24G | 8 | AMP FP16 | 12 GB | 2H
| SAM-det | ResNet50 | 1024x1024 | 1x RTX 4090 24G | 8 | FP32 | 16.6 GB |
| RSPrompter-anchor | ViT-B/16 | 1024x1024 | 1x RTX 4090 24G | 2 | AMP FP16 | 20.9 GB |
| RSPrompter-query | ViT-B/16 | 1024x1024 | 1x RTX 4090 24G | 1 | AMP FP16 | OOM |
| RSPrompter-query | ViT-B/16 | 1024x1024 | 8x NVIDIA A100 40G | 1 | ZeRO-2 | 39.6 GB |
| RSPrompter-anchor | ViT-B/16 | 512x512 | 8x RTX 4090 24G | 4 | AMP FP16 | 20.2 GB |
| RSPrompter-query | ViT-B/16 | 512x512 | 8x RTX 4090 24G | 2 | ZeRO-2 | 21.1 GB |
这里我们列出了使用不同训练方法的资源消耗情况,供您参考。

注解:低分辨率输入图像可以有效减小模型的显存占用,但是其实际性能并未验证,具体见[配置文件](configs/rsprompter/rsprompter_query-nwpu-peft-512.py)
| 模型名称 | 骨干网络类型 | 图像尺寸 | GPU | Batch Size | 加速策略 | 单卡显存占用 | 训练时间 |
|:----:|:--------:|:-------:|:---------------:|:----------:|:----:|:-------:|:----:|
| TTP | ViT-L/16 | 512x512 | 4x RTX 4090 24G | 2 | FP32 | 14 GB | 3H |
| TTP | ViT-L/16 | 512x512 | 4x RTX 4090 24G | 2 | FP16 | 12 GB | 2H |


### 4. dist_train.sh: Bad substitution的解决
Expand All @@ -337,23 +299,19 @@ python zero_to_fp32.py . $SAVE_CHECKPOINT_NAME -t $CHECKPOINT_DIR # $SAVE_CHECK

### 5. You should set `PYTHONPATH` to make `sys.path` include the directory which contains your custom module

请查看详细的报错信息,一般是某些依赖包没有安装,请使用`pip install`来安装依赖包。
</details>

## 致谢

本项目基于 [MMDetection](https://github.com/open-mmlab/mmdetection/tree/main) 项目进行开发,感谢 MMDetection 项目的开发者们。
本项目基于 [MMSegmentation](https://github.com/open-mmlab/mmsegmentation)[Open-CD](https://github.com/likyoo/open-cd) 项目进行开发,感谢 MMSegmentation 和 Open-CD 项目的开发者们。

## 引用

如果你在研究中使用了本项目的代码或者性能基准,请参考如下 bibtex 引用 RSPrompter
如果你在研究中使用了本项目的代码或者性能基准,请参考如下 bibtex 引用 TTP

```
@article{chen2023rsprompter,
title={RSPrompter: Learning to prompt for remote sensing instance segmentation based on visual foundation model},
author={Chen, Keyan and Liu, Chenyang and Chen, Hao and Zhang, Haotian and Li, Wenyuan and Zou, Zhengxia and Shi, Zhenwei},
journal={arXiv preprint arXiv:2306.16269},
year={2023}
}
xxx
```

## 开源许可证
Expand Down
6 changes: 2 additions & 4 deletions configs/TTP/ttp_sam_large_levircd_fp16.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,10 @@
param_scheduler=dict(type='ParamSchedulerHook'),
checkpoint=dict(type='CheckpointHook', by_epoch=True, interval=10, save_best='cd/iou_changed', max_keep_ckpts=5, greater_keys=['cd/iou_changed'], save_last=True),
sampler_seed=dict(type='DistSamplerSeedHook'),
visualization=dict(type='CDVisualizationHook', interval=1,
img_shape=(1024, 1024, 3))
visualization=dict(type='CDVisualizationHook', interval=1, img_shape=(1024, 1024, 3))
)
vis_backends = [dict(type='CDLocalVisBackend'),
dict(type='WandbVisBackend',
init_kwargs=dict(project='samcd', group='levircd', name='ttp_sam_large_levircd_fp16'))
dict(type='WandbVisBackend', init_kwargs=dict(project='samcd', group='levircd', name='ttp_sam_large_levircd_fp16'))
]

visualizer = dict(
Expand Down

0 comments on commit f84519b

Please sign in to comment.