Harnessing Edge Information for Improved Robustness in Vision Transformers

Yanxi Li, Chengbin Du, and Chang Xu (University of Sydney)

In Proceedings of the AAAI Conference on Artificial Intelligence, 2024

[Paper] [Poster]

[Figure: EdgeNet architecture diagram]

Abstract

Deep Neural Networks (DNNs) have demonstrated remarkable accuracy in vision classification tasks. However, they are vulnerable to small added perturbations known as adversarial attacks. Previous studies hypothesize that this vulnerability stems from high-accuracy DNNs relying heavily on irrelevant and non-robust features, such as textures and the background. In this work, we reveal that edge information extracted from images can provide relevant and robust features related to shapes and the foreground. These features assist pretrained DNNs in achieving improved adversarial robustness without compromising their accuracy on clean images. We propose a lightweight, plug-and-play EdgeNet that can be seamlessly integrated into existing pretrained DNNs, including Vision Transformers, a recent family of state-of-the-art models for vision classification. EdgeNet can process edges derived from either clean natural images or noisy adversarial images, yielding robust features that are injected into the intermediate layers of the frozen backbone DNN. The cost of obtaining such edges with conventional edge detection algorithms (e.g., the Canny edge detector) is marginal, and the cost of training EdgeNet is comparable to fine-tuning the backbone network with techniques such as Adapter.
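The abstract describes injecting edge features into the intermediate layers of a frozen backbone, and the configuration below exposes a zero_conv_channel option. The sketch here shows one way such a zero-initialized injection can work; the module name EdgeInjection and its interface are illustrative assumptions, not the repository's implementation.

import torch
import torch.nn as nn

class EdgeInjection(nn.Module):
    """Hypothetical sketch of zero-initialized feature injection.

    Edge features are projected and added to a frozen backbone
    block's output. Because the projection starts at zero, the
    combined model initially behaves exactly like the pretrained
    backbone, preserving clean accuracy at the start of training.
    """

    def __init__(self, edge_dim: int, embed_dim: int):
        super().__init__()
        self.proj = nn.Linear(edge_dim, embed_dim)
        nn.init.zeros_(self.proj.weight)
        nn.init.zeros_(self.proj.bias)

    def forward(self, backbone_tokens, edge_tokens):
        # contributes nothing at step 0; grows as training proceeds
        return backbone_tokens + self.proj(edge_tokens)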

Requirements

pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu117
pip install timm==0.6.13
pip install opencv-python==4.8.0.74

Usage

Build model

import torch

from PIL import Image
from timm.models import create_model
from timm.data import transforms as timm_transforms
from torchvision import transforms

from config import Config
from canny_edge import CannyEdge
from model import robust_vit_base_patch16_224


# get config
config = Config('rob_vit_b16.yaml')
args = config.load_config()

# get edge detector
edge_fn = CannyEdge(
    auto=args.canny_threshold_auto,
    lower=args.canny_threshold_lower,
    upper=args.canny_threshold_upper,
)

# get model
model_kwargs = dict(
    num_classes=args.num_classes,
    drop_rate=args.drop,
    drop_path_rate=args.drop_path,
    global_pool=args.gp,
    bn_momentum=args.bn_momentum,
    bn_eps=args.bn_eps,
    start_index=args.start_index,
    interval=args.interval,
    zero_conv_channel=args.zero_conv_channel
)
model = create_model(
    model_name=args.model,
    pretrained=args.pretrained,
    checkpoint_path=args.initial_checkpoint,
    scriptable=args.torchscript,
    **model_kwargs,
)

train_transforms = transforms.Compose([
    timm_transforms.RandomResizedCropAndInterpolation(
        size=(224, 224),
        scale=(0.08, 1.0),
        ratio=(0.75, 1.3333),
        interpolation='random',
    ),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ColorJitter(
        brightness=[0.6, 1.4],
        contrast=[0.6, 1.4],
        saturation=[0.6, 1.4],
    ),
    transforms.ToTensor(),
])

eval_transforms = transforms.Compose([
    transforms.Resize(
        size=248,
        interpolation=transforms.functional.InterpolationMode.BICUBIC, 
        max_size=None,
        antialias=None
    ),
    transforms.CenterCrop(size=(224, 224)),
    transforms.ToTensor(),
])

"""
NOTE: Normalization is included in the model for the convenience of the attack.
"""

# normalizer for the 3-channel RGB image
model.train_transforms_norm = model.eval_transforms_norm = transforms.Compose([
    transforms.Normalize(
        mean=torch.tensor([0.5000, 0.5000, 0.5000]),
        std=torch.tensor([0.5000, 0.5000, 0.5000]),
    )
])

# normalizer for the 1-channel edge map
model.train_transforms_norm_c = model.eval_transforms_norm_c = transforms.Compose([
    transforms.Normalize(
        mean=torch.tensor([0.5000]),
        std=torch.tensor([0.5000]),
    )
])

# move the model to the GPU if one is available
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)

Training forward example

model.train()

# get data
image_PIL = Image.open('data/train/n02087394_103.JPEG')
image_PIL = image_PIL.convert("RGB")
image_tensor = train_transforms(image_PIL)

# detect edges: CannyEdge appends a 1-channel edge map to the RGB
# image, giving a 4-channel tensor
image_with_edge = edge_fn(image_tensor)
image_with_edge = image_with_edge.unsqueeze(0).to(device)  # add batch dim

# split the RGB image and the edge map
x, edge = image_with_edge[:, :3], image_with_edge[:, 3:]

print(x.size(), edge.size())
# expect: 'torch.Size([1, 3, 224, 224]) torch.Size([1, 1, 224, 224])'

output = model(x, edge)
print(output.size())
# expect: 'torch.Size([1, 1000])'
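The backbone is kept frozen and only the EdgeNet branch is trained (the abstract compares this cost to Adapter-style fine-tuning). The loop below is a minimal sketch of that setup; selecting parameters by an 'edge' substring, the optimizer settings, and the class index are assumptions for illustration, not the repository's training recipe.

import torch.nn.functional as F

# freeze the pretrained backbone; collect only EdgeNet parameters
# (matching by an 'edge' substring is an assumed naming convention)
edge_params = []
for name, param in model.named_parameters():
    if 'edge' in name:
        param.requires_grad = True
        edge_params.append(param)
    else:
        param.requires_grad = False

optimizer = torch.optim.AdamW(edge_params, lr=1e-4, weight_decay=0.05)

target = torch.tensor([162], device=output.device)  # hypothetical label
loss = F.cross_entropy(model(x, edge), target)

optimizer.zero_grad()
loss.backward()
optimizer.step()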

Evaluation forward example

model.eval()

# get data
image_PIL = Image.open('data/val/ILSVRC2012_val_00000077.JPEG')
image_PIL = image_PIL.convert("RGB")
image_tensor = eval_transforms(image_PIL)

# forward: append the edge map, add a batch dimension, move to device
image_with_edge = edge_fn(image_tensor)
image_with_edge = image_with_edge.unsqueeze(0).to(device)
x, edge = image_with_edge[:, :3], image_with_edge[:, 3:]

print(x.size(), edge.size())
# expect: 'torch.Size([1, 3, 224, 224]) torch.Size([1, 1, 224, 224])'

output = model(x, edge)
print(output.size())
# expect: 'torch.Size([1, 1000])'
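Because normalization is applied inside the model (see the note above), attacks can operate directly on [0, 1] pixel values. Below is a minimal PGD-style sketch under that assumption; the attack budget is illustrative, and the edge map is held fixed during the attack because the Canny detector is not differentiable. None of this is claimed to be the paper's exact evaluation protocol.

import torch.nn.functional as F

eps, alpha, steps = 4 / 255, 1 / 255, 10  # assumed attack budget
label = output.argmax(dim=1)              # attack the predicted class

x_adv = x.clone().detach()
for _ in range(steps):
    x_adv.requires_grad_(True)
    # gradients flow through the RGB branch only; the edge input is
    # fixed here since Canny edge detection is not differentiable
    loss = F.cross_entropy(model(x_adv, edge), label)
    grad = torch.autograd.grad(loss, x_adv)[0]
    x_adv = x_adv.detach() + alpha * grad.sign()
    x_adv = x.detach() + (x_adv - x).clamp(-eps, eps)
    x_adv = x_adv.clamp(0.0, 1.0)

adv_output = model(x_adv, edge)
print(adv_output.argmax(dim=1))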

Canny edge detector

[Figure: sample images and their Canny edge maps]
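The edge inputs are produced by the classic Canny detector; the repository wraps it in the CannyEdge class used above. The standalone sketch below shows the general idea with OpenCV: detect edges on a grayscale copy of the image and append them as a fourth channel. The default thresholds and the median-based auto-threshold heuristic in the comment are assumptions about the implementation.

import cv2
import numpy as np
import torch

def canny_edge(image_tensor: torch.Tensor,
               lower: int = 100, upper: int = 200) -> torch.Tensor:
    """Sketch: append a Canny edge map to a CHW tensor in [0, 1]."""
    # CHW float tensor in [0, 1] -> HWC uint8 image for OpenCV
    image_np = (image_tensor.permute(1, 2, 0).numpy() * 255).astype(np.uint8)
    gray = cv2.cvtColor(image_np, cv2.COLOR_RGB2GRAY)
    # automatic thresholds are often derived from the median, e.g.
    # v = np.median(gray); lower, upper = int(0.67 * v), int(1.33 * v)
    edges = cv2.Canny(gray, lower, upper)  # uint8 map with values {0, 255}
    edge_tensor = torch.from_numpy(edges).float().unsqueeze(0) / 255.0
    return torch.cat([image_tensor, edge_tensor], dim=0)  # 4 x H x W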

Citation

@inproceedings{li2024harnessing,
  title={Harnessing Edge Information for Improved Robustness in Vision Transformers},
  author={Li, Yanxi and Du, Chengbin and Xu, Chang},
  booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
  volume={38},
  number={4},
  pages={3252--3260},
  year={2024}
}
