
Commit

add
陈科研 committed Nov 26, 2023
1 parent 6447e2f commit 38c2eae
Showing 3 changed files with 238 additions and 2 deletions.
236 changes: 236 additions & 0 deletions configs/rsprompter/rsprompter_anchor-nwpu-peft-512.py
@@ -0,0 +1,236 @@
_base_ = ['_base_/rsprompter_anchor.py']

work_dir = './work_dirs/rsprompter/rsprompter_anchor-nwpu-peft-512'

default_hooks = dict(
    timer=dict(type='IterTimerHook'),
    logger=dict(type='LoggerHook', interval=5),
    param_scheduler=dict(type='ParamSchedulerHook'),
    checkpoint=dict(type='CheckpointHook', interval=1, max_keep_ckpts=5, save_best='coco/bbox_mAP', rule='greater', save_last=True),
    sampler_seed=dict(type='DistSamplerSeedHook'),
    # visualization=dict(type='DetVisualizationHook', draw=True, interval=1, test_out_dir='vis_data')
)

vis_backends = [dict(type='LocalVisBackend'),
                # dict(type='WandbVisBackend', init_kwargs=dict(project='rsprompter-nwpu', group='rsprompter-query', name="rsprompter_anchor-nwpu-peft-512"))
]
visualizer = dict(
    type='DetLocalVisualizer', vis_backends=vis_backends, name='visualizer')

num_classes = 10
prompt_shape = (70, 5)  # (point sets per image, points per point set)

#### should be changed when using a different pretrained model

# sam base model
hf_sam_pretrain_name = "facebook/sam-vit-base"
hf_sam_pretrain_ckpt_path = "pretrain_models/huggingface/hub/models--facebook--sam-vit-base/snapshots/b5fc59950038394bae73f549a55a9b46bc6f3d96/pytorch_model.bin"
# # sam large model
# hf_sam_pretrain_name = "facebook/sam-vit-large"
# hf_sam_pretrain_ckpt_path = "pretrain_models/huggingface/hub/models--facebook--sam-vit-large/snapshots/70009d56dac23ebb3265377257158b1d6ed4c802/pytorch_model.bin"
# # sam huge model
# hf_sam_pretrain_name = "facebook/sam-vit-huge"
# hf_sam_pretrain_ckpt_path = "pretrain_models/huggingface/hub/models--facebook--sam-vit-huge/snapshots/89080d6dcd9a900ebd712b13ff83ecf6f072e798/pytorch_model.bin"

crop_size = (512, 512)

batch_augments = [
    dict(
        type='BatchFixedSizePad',
        size=crop_size,
        img_pad_value=0,
        pad_mask=True,
        mask_pad_value=0,
        pad_seg=False)
]

data_preprocessor = dict(
    type='DetDataPreprocessor',
    mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
    std=[0.229 * 255, 0.224 * 255, 0.225 * 255],
    bgr_to_rgb=True,
    pad_mask=True,
    pad_size_divisor=32,
    batch_augments=batch_augments
)

model = dict(
    decoder_freeze=False,
    data_preprocessor=data_preprocessor,
    shared_image_embedding=dict(
        hf_pretrain_name=hf_sam_pretrain_name,
        init_cfg=dict(type='Pretrained', checkpoint=hf_sam_pretrain_ckpt_path),
    ),
    backbone=dict(
        _delete_=True,
        img_size=crop_size[0],
        type='MMPretrainSamVisionEncoder',
        hf_pretrain_name=hf_sam_pretrain_name,
        init_cfg=dict(type='Pretrained', checkpoint=hf_sam_pretrain_ckpt_path),
        peft_config=dict(
            peft_type="LORA",
            r=16,
            target_modules=["qkv"],
            lora_alpha=32,
            lora_dropout=0.05,
            bias="none",
        ),
    ),
    neck=dict(
        feature_aggregator=dict(
            _delete_=True,
            type='PseudoFeatureAggregator',
            in_channels=256,
            hidden_channels=512,
            out_channels=256,
        ),
    ),
    roi_head=dict(
        bbox_head=dict(
            num_classes=num_classes,
        ),
        mask_head=dict(
            mask_decoder=dict(
                hf_pretrain_name=hf_sam_pretrain_name,
                init_cfg=dict(type='Pretrained', checkpoint=hf_sam_pretrain_ckpt_path)
            ),
            per_pointset_point=prompt_shape[1],
            with_sincos=True,
        ),
    ),
)


backend_args = None
train_pipeline = [
    dict(type='LoadImageFromFile', backend_args=backend_args, to_float32=True),
    dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
    dict(type='RandomFlip', prob=0.5),
    # large scale jittering
    dict(
        type='RandomResize',
        scale=crop_size,
        ratio_range=(0.1, 2.0),
        resize_type='Resize',
        keep_ratio=True),
    dict(
        type='RandomCrop',
        crop_size=crop_size,
        crop_type='absolute',
        recompute_bbox=True,
        allow_negative_crop=True),
    dict(type='FilterAnnotations', min_gt_bbox_wh=(1e-5, 1e-5), by_mask=True),
    dict(type='PackDetInputs')
]

test_pipeline = [
    dict(type='LoadImageFromFile', backend_args=backend_args, to_float32=True),
    dict(type='Resize', scale=crop_size, keep_ratio=True),
    dict(type='Pad', size=crop_size, pad_val=dict(img=(0.406 * 255, 0.456 * 255, 0.485 * 255), masks=0)),
    # If there are no gt annotations, delete this step from the pipeline
    dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
    dict(
        type='PackDetInputs',
        meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 'pad_shape', 'scale_factor'))
]

dataset_type = 'NWPUInsSegDataset'
#### should be changed to match your code root and data root
code_root = '/mnt/search01/usr/chenkeyan/codes/mm_rsprompter'
data_root = '/mnt/search01/dataset/cky_data/NWPU10'

batch_size_per_gpu = 8
num_workers = 8
persistent_workers = True
train_dataloader = dict(
    batch_size=batch_size_per_gpu,
    num_workers=num_workers,
    persistent_workers=persistent_workers,
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        ann_file=code_root + '/data/NWPU/annotations/NWPU_instances_train.json',
        data_prefix=dict(img='imgs'),
        pipeline=train_pipeline,
    )
)

val_dataloader = dict(
    batch_size=batch_size_per_gpu,
    num_workers=num_workers,
    persistent_workers=persistent_workers,
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        ann_file=code_root + '/data/NWPU/annotations/NWPU_instances_val.json',
        data_prefix=dict(img='imgs'),
        pipeline=test_pipeline,
    )
)

test_dataloader = val_dataloader
resume = False
load_from = None

base_lr = 0.0002
max_epochs = 500

train_cfg = dict(max_epochs=max_epochs)
param_scheduler = [
    dict(
        type='LinearLR', start_factor=0.001, by_epoch=False, begin=0, end=50),
    dict(
        type='CosineAnnealingLR',
        eta_min=base_lr * 0.001,
        begin=1,
        end=max_epochs,
        T_max=max_epochs,
        by_epoch=True
    )
]

#### AMP training config
runner_type = 'Runner'
optim_wrapper = dict(
    type='AmpOptimWrapper',
    dtype='float16',
    optimizer=dict(
        type='AdamW',
        lr=base_lr,
        weight_decay=0.05)
)
#
# #### DeepSpeed Configs
# runner_type = 'FlexibleRunner'
# strategy = dict(
#     type='DeepSpeedStrategy',
#     fp16=dict(
#         enabled=True,
#         auto_cast=False,
#         fp16_master_weights_and_grads=False,
#         loss_scale=0,
#         loss_scale_window=500,
#         hysteresis=2,
#         min_loss_scale=1,
#         initial_scale_power=15,
#     ),
#     gradient_clipping=0.1,
#     inputs_to_half=['inputs'],
#     zero_optimization=dict(
#         stage=2,
#         allgather_partitions=True,
#         allgather_bucket_size=2e8,
#         reduce_scatter=True,
#         reduce_bucket_size='auto',
#         overlap_comm=True,
#         contiguous_gradients=True,
#     ),
# )
# optim_wrapper = dict(
#     type='DeepSpeedOptimWrapper',
#     optimizer=dict(
#         type='AdamW',
#         lr=base_lr,
#         weight_decay=0.05
#     )
# )
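The peft_config block in the new backbone is what makes this the "-peft-" variant: the SAM ViT-B image encoder is not fully fine-tuned; instead, low-rank LoRA adapters injected into its attention qkv projections carry the trainable capacity (r=16, lora_alpha=32, dropout 0.05, no bias updates), while the detection heads and the unfrozen mask decoder (decoder_freeze=False) train as usual. A minimal sketch of an equivalent setup with HuggingFace's peft package follows; it assumes MMPretrainSamVisionEncoder ultimately forwards this dict to peft, and the SamModel/LoraConfig usage here is illustrative rather than the repository's actual code path.

from transformers import SamModel
from peft import LoraConfig, get_peft_model

# Illustrative only: wrap the HF SAM encoder with LoRA adapters matching peft_config above.
sam = SamModel.from_pretrained("facebook/sam-vit-base")
lora_cfg = LoraConfig(
    r=16,                    # low-rank dimension (peft_config.r)
    lora_alpha=32,           # scaling factor (peft_config.lora_alpha)
    lora_dropout=0.05,
    target_modules=["qkv"],  # the fused qkv projection inside the ViT attention blocks
    bias="none",
)
sam = get_peft_model(sam, lora_cfg)
sam.print_trainable_parameters()  # only the adapter weights remain trainable

Training only the adapters while the base encoder weights stay frozen is what keeps this config far cheaper than full fine-tuning of SAM.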
2 changes: 1 addition & 1 deletion configs/rsprompter/rsprompter_anchor-nwpu.py
@@ -106,7 +106,7 @@
load_from = None

base_lr = 0.0002
-max_epochs = 600
+max_epochs = 500

train_cfg = dict(max_epochs=max_epochs)

2 changes: 1 addition & 1 deletion configs/rsprompter/rsprompter_query-nwpu-peft-512.py
@@ -12,7 +12,7 @@
)

vis_backends = [dict(type='LocalVisBackend'),
-                # dict(type='WandbVisBackend', init_kwargs=dict(project='rsprompter-nwpu', group='rsprompter-query', name="rsprompter_query-nwpu"))
+                # dict(type='WandbVisBackend', init_kwargs=dict(project='rsprompter-nwpu', group='rsprompter-query', name="rsprompter_query-nwpu-peft-512"))
]
visualizer = dict(
    type='DetLocalVisualizer', vis_backends=vis_backends, name='visualizer')
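All three configs are declarative MMDetection 3.x / MMEngine Python configs, so besides the repository's own training script they can be driven directly by MMEngine's Runner. A minimal launch sketch, assuming standard MMEngine APIs (nothing below is taken from the commit itself):

from mmengine.config import Config
from mmengine.runner import Runner

# Build the model, dataloaders, hooks, and the AMP optim wrapper declared in the config.
cfg = Config.fromfile('configs/rsprompter/rsprompter_anchor-nwpu-peft-512.py')
runner = Runner.from_cfg(cfg)
runner.train()

The commented-out DeepSpeed block would instead go through FlexibleRunner with DeepSpeedStrategy, which is presumably why the config also exposes a runner_type switch.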
