Skip to content

Latest commit

 

History

History
418 lines (279 loc) · 16.7 KB

README_zh-CN.md

File metadata and controls

418 lines (279 loc) · 16.7 KB

RSPrompter: Learning to Prompt for Remote Sensing Instance Segmentation based on Visual Foundation Model





GitHub stars license arXiv Hugging Face Spaces



English | 简体中文

简介

本项目仓库是论文 RSPrompter: Learning to Prompt for Remote Sensing Instance Segmentation based on Visual Foundation Model 的代码实现,基于 MMDetection 项目进行开发。

当前分支在PyTorch 2.x 和 CUDA 12.1 下测试通过,支持 Python 3.8+,能兼容绝大多数的 CUDA 版本。

如果你觉得本项目对你有帮助,请给我们一个 star ⭐️,你的支持是我们最大的动力。

主要特性
  • 与 MMDetection 高度保持一致的 API 接口及使用方法
  • 开源了论文中的 SAM-seg、SAM-det、RSPrompter 等模型
  • 通过了 AMP、DeepSpeed 等训练方式的测试
  • 支持了多种数据集的训练和测试

更新日志

🌟 2023.06.29 发布了 RSPrompter 项目,基于Lightning和MMDetection实现了论文中的SAM-seg、SAM-det、RSPrompter等模型。

🌟 2023.11.25 更新了RSPrompter的代码,完全与MMDetection保持一致的API接口及使用方法。

🌟 2023.11.26 加入了LoRA高效微调方法,并使得输入图像尺寸可变,减小了模型的显存占用。

🌟 2023.11.26 提供了各个模型的显存占用参考,具体见常见问题

🌟 2023.11.30 更新了论文内容,见Arxiv

TODO

  • 与MMDetection保持一致的API接口及使用方法
  • 通过减小图像输入并结合大模型微调技术在保证性能的同时减小模型的显存占用
  • 动态可变的图像尺寸输入
  • 在模型中加入高效微调的方法
  • 加入SAM-cls模型

目录

安装

依赖项

  • Linux 或 Windows
  • Python 3.8+,推荐使用 3.10
  • PyTorch 2.0 或更高版本,推荐使用 2.1
  • CUDA 11.7 或更高版本,推荐使用 12.1
  • MMCV 2.0 或更高版本,推荐使用 2.1

环境安装

推荐使用 Miniconda 来进行安装,以下命令将会创建一个名为 rsprompter 的虚拟环境,并安装 PyTorch 和 MMCV。

注解:如果你对 PyTorch 有经验并且已经安装了它,你可以直接跳转到下一小节。否则,你可以按照下述步骤进行准备。

步骤 0:安装 Miniconda

步骤 1:创建一个名为 rsprompter 的虚拟环境,并激活它。

conda create -n rsprompter python=3.10 -y
conda activate rsprompter

步骤 2:安装 PyTorch2.1.x

Linux/Windows:

pip install torch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 --index-url https://download.pytorch.org/whl/cu121

或者

conda install pytorch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 pytorch-cuda=12.1 -c pytorch -c nvidia

步骤 3:安装 MMCV2.1.x

pip install -U openmim
mim install mmcv==2.1.0

步骤 4:安装其他依赖项。

pip install -U transformers==4.38.1 wandb==0.16.3 einops pycocotools shapely scipy terminaltables importlib peft==0.8.2 mat4py==0.6.0 mpi4py

步骤 5:[可选] 安装 DeepSpeed。

如果您想使用 DeepSpeed 训练模型,您需要安装 DeepSpeed。DeepSpeed 的安装方法可以参考 DeepSpeed 官方文档

pip install deepspeed==0.13.4

注解:Windows 系统下对 DeepSpeed 的支持尚未完善,我们建议您在 Linux 系统下使用 DeepSpeed。

安装 RSPrompter

下载或克隆 RSPrompter 仓库即可。

git clone git@github.com:KyanChen/RSPrompter.git
cd RSPrompter

数据集准备

基础实例分割数据集

我们提供论文中使用的实例分割数据集准备方法。

WHU 建筑物数据集

  • 图片下载地址: WHU建筑物数据集

  • 语义标签转实例标签:我们提供了相应的转换脚本来将 WHU 建筑物数据集的语义标签转换为实例标签。

NWPU VHR-10 数据集

SSDD 数据集

注解:在本项目的 data 文件夹中,我们提供了上述数据集的实例标签,你可以直接使用。

组织方式

你也可以选择其他来源进行数据的下载,但是需要将数据集组织成如下的格式:

${DATASET_ROOT} # 数据集根目录,例如:/home/username/data/NWPU
├── annotations
│   ├── train.json
│   ├── val.json
│   └── test.json
└── images
    ├── train
    ├── val
    └── test

注解:在项目文件夹中,我们提供了一个名为 data 的文件夹,其中包含了上述数据集的组织方式的示例。

其他数据集

如果你想使用其他数据集,可以参考 MMDetection 文档 来进行数据集的准备。

模型训练

SAM-based 模型

Config 文件及主要参数解析

我们提供了论文中使用的 SAM-based 模型的配置文件,你可以在 configs/rsprompter 文件夹中找到它们。Config 文件完全与 MMDetection 保持一致的 API 接口及使用方法。下面我们提供了一些主要参数的解析。如果你想了解更多参数的含义,可以参考 MMDetection 文档

参数解析

  • work_dir:模型训练的输出路径,一般不需要修改。
  • default_hooks-CheckpointHook:模型训练过程中的检查点保存配置,一般不需要修改。
  • default_hooks-visualization:模型训练过程中的可视化配置,训练时注释,测试时取消注释
  • vis_backends-WandbVisBackend:网络端可视化工具的配置,打开注释后,需要在 wandb 官网上注册账号,可以在网络浏览器中查看训练过程中的可视化结果
  • num_classes:数据集的类别数,需要根据数据集的类别数进行修改
  • prompt_shape:Prompt 的形状,第一个参数代表 $N_p$,第二个参数代表 $K_p$,一般不需要修改。
  • hf_sam_pretrain_name:HuggingFace Spaces 上的 SAM 模型的名称,可以改为自己下载到本地的路径,下载脚本
  • hf_sam_pretrain_ckpt_path:HuggingFace Spaces 上的 SAM 模型的检查点路径,需要修改为你自己的路径,可以使用下载脚本来下载。
  • model-decoder_freeze:是否冻结SAM解码器的参数,一般不需要修改。
  • model-neck-feature_aggregator-hidden_channels:特征聚合器的隐藏通道数,一般不需要修改。
  • model-neck-feature_aggregator-select_layers:特征聚合器的选择层数,需要根据选择的SAM骨干类型进行修改
  • model-mask_head-with_sincos:是否在预测提示时使用 sin 正则化,一般不需要修改。
  • dataset_type:数据集的类型,需要根据数据集的类型进行修改
  • code_root:代码根目录,修改为本项目根目录的绝对路径
  • data_root:数据集根目录,修改为数据集根目录的绝对路径
  • batch_size_per_gpu:单卡的 batch size,需要根据显存大小进行修改
  • resume: 是否断点续训,一般不需要修改。
  • load_from:模型的预训练的检查点路径,一般不需要修改。
  • max_epochs:最大训练轮数,一般不需要修改。
  • runner_type:训练器的类型,需要和optim_wrapperstrategy的类型保持一致,一般不需要修改。

单卡训练

python tools/train.py configs/rsprompter/xxx.py  # xxx.py 为你想要使用的配置文件

多卡训练

sh ./tools/dist_train.sh configs/rsprompter/xxx.py ${GPU_NUM}  # xxx.py 为你想要使用的配置文件,GPU_NUM 为使用的 GPU 数量

其他实例分割模型

如果你想使用其他实例分割模型,可以参考 MMDetection 来进行模型的训练,也可以将其Config文件放入本项目的 configs 文件夹中,然后按照上述的方法进行训练。

模型测试

单卡测试:

python tools/test.py configs/rsprompter/xxx.py ${CHECKPOINT_FILE}  # xxx.py 为你想要使用的配置文件,CHECKPOINT_FILE 为你想要使用的检查点文件

多卡测试:

sh ./tools/dist_test.sh configs/rsprompter/xxx.py ${CHECKPOINT_FILE} ${GPU_NUM}  # xxx.py 为你想要使用的配置文件,CHECKPOINT_FILE 为你想要使用的检查点文件,GPU_NUM 为使用的 GPU 数量

注解:如果需要获取可视化结果,可以在 Config 文件中取消 default_hooks-visualization 的注释。

图像预测

单张图像预测:

python demo/image_demo.py ${IMAGE_FILE}  configs/rsprompter/xxx.py --weights ${CHECKPOINT_FILE} --out-dir ${OUTPUT_DIR}  # IMAGE_FILE 为你想要预测的图像文件,xxx.py 为你想要使用的配置文件,CHECKPOINT_FILE 为你想要使用的检查点文件,OUTPUT_DIR 为预测结果的输出路径

多张图像预测:

python demo/image_demo.py ${IMAGE_DIR}  configs/rsprompter/xxx.py --weights ${CHECKPOINT_FILE} --out-dir ${OUTPUT_DIR}  # IMAGE_DIR 为你想要预测的图像文件夹,xxx.py 为你想要使用的配置文件,CHECKPOINT_FILE 为你想要使用的检查点文件,OUTPUT_DIR 为预测结果的输出路径

常见问题

我们在这里列出了使用时的一些常见问题及其相应的解决方案。如果您发现有一些问题被遗漏,请随时提 PR 丰富这个列表。如果您无法在此获得帮助,请使用issue来寻求帮助。请在模板中填写所有必填信息,这有助于我们更快定位问题。

1. 是否需要安装MMDetection?

我们建议您不要安装MMDetection,因为我们已经对MMDetection的代码进行了部分修改,如果您安装了MMDetection,可能会导致代码运行出错。如果你出现了模块尚未被注册的错误,请检查:

  • 是否安装了MMDetection,若有则卸载
  • 是否在类名前加上了@MODELS.register_module(),若没有则加上
  • 是否在__init__.py中加入了from .xxx import xxx,若没有则加上
  • 是否在Config文件中加入了custom_imports = dict(imports=['mmdet.rsprompter'], allow_failed_imports=False),若没有则加上

2. 使用DeepSpeed训练时后如何评测模型?

我们建议您使用DeepSpeed训练模型,因为DeepSpeed可以大幅度提升模型的训练速度。但是,DeepSpeed的训练方式与MMDetection的训练方式不同,因此在使用DeepSpeed训练模型后,需要使用MMDetection的方式进行评测。具体来说,您需要:

  • 将DeepSpeed训练的模型转换为MMDetection的模型,进入到存储模型的文件夹,运行
python zero_to_fp32.py . $SAVE_CHECKPOINT_NAME -t $CHECKPOINT_DIR  # $SAVE_CHECKPOINT_NAME为转换后的模型名称,$CHECKPOINT_DIR为DeepSpeed训练的模型名称
  • 将Config文件中的runner_type改为Runner
  • 使用MMDetection的方式进行评测,即可得到评测结果。

3. 关于资源消耗情况

这里我们列出了使用不同模型的资源消耗情况,供您参考。

模型名称 骨干网络类型 图像尺寸 GPU Batch Size 加速策略 单卡显存占用
SAM-seg (Mask R-CNN) ViT-B/16 1024x1024 1x RTX 4090 24G 8 AMP FP16 19.4 GB
SAM-seg (Mask2Former) ViT-B/16 1024x1024 1x RTX 4090 24G 8 AMP FP16 21.5 GB
SAM-det ResNet50 1024x1024 1x RTX 4090 24G 8 FP32 16.6 GB
RSPrompter-anchor ViT-B/16 1024x1024 1x RTX 4090 24G 2 AMP FP16 20.9 GB
RSPrompter-query ViT-B/16 1024x1024 1x RTX 4090 24G 1 AMP FP16 OOM
RSPrompter-query ViT-B/16 1024x1024 8x NVIDIA A100 40G 1 ZeRO-2 39.6 GB
RSPrompter-anchor ViT-B/16 512x512 8x RTX 4090 24G 4 AMP FP16 20.2 GB
RSPrompter-query ViT-B/16 512x512 8x RTX 4090 24G 2 ZeRO-2 21.1 GB

注解:低分辨率输入图像可以有效减小模型的显存占用,但是其实际性能并未验证,具体见配置文件

4. dist_train.sh: Bad substitution的解决

如果您在运行dist_train.sh时出现了Bad substitution的错误,请使用bash dist_train.sh来运行脚本。

5. 无法下载访问和下载HuggingFace Spaces上的模型

如果您无法访问和下载HuggingFace Spaces上的模型,请使用下载脚本来下载。 请参考官方处理方式

6. 分割loss一直为0或者出现Nan

由于batch size过小导致训练不稳定,有以下不同的解决方案,任选其一即可:

  1. 增大batch size为2或者4(可能会出现显存不够);

  2. 采用梯度累加方法(修改Config文件中的optim_wrapper);

optim_wrapper = dict(
type='AmpOptimWrapper',
dtype='float16', # 修改为bfloat16 更稳定
optimizer=dict(
    type='AdamW',
    lr=base_lr,
    weight_decay=0.05),
accumulative_counts=4  # 需要增加的配置,修改为4或其他大于1的数
)
  1. 取消解码时Prompter中正余弦变换(修改Config文件中的with_sincos=False);

  2. 采用输入图像为512的peft配置,增大batch size。

致谢

本项目基于 MMDetection 项目进行开发,感谢 MMDetection 项目的开发者们。

引用

如果你在研究中使用了本项目的代码或者性能基准,请参考如下 bibtex 引用 RSPrompter。

@article{chen2024rsprompter,
  title={RSPrompter: Learning to prompt for remote sensing instance segmentation based on visual foundation model},
  author={Chen, Keyan and Liu, Chenyang and Chen, Hao and Zhang, Haotian and Li, Wenyuan and Zou, Zhengxia and Shi, Zhenwei},
  journal={IEEE Transactions on Geoscience and Remote Sensing},
  year={2024},
  publisher={IEEE}
}

开源许可证

该项目采用 Apache 2.0 开源许可证

联系我们

如果有其他问题❓,请及时与我们联系 👬