Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implementation of Auto Conjugate Gradient Attack #2028

Merged
merged 37 commits into from
Mar 11, 2023

Conversation

yamamura-k
Copy link
Contributor

Description

I implemented a new attack method Auto Conjugate Gradient attack proposed in "Diversified Adversarial Attacks based on Conjugate Gradient Method", ICML2022. paper link(arxiv)

This implementation works poor when the batch size is greater than or equal to 2 because the way to treat the loss and step size condition is different from the original implementation. The implementation of APGD also has the same issue, and to fix this, we have to modify the wrapper class of the threat model (e.g. PyTorchClassifier). Due to this reason, we imitate the implementation of APGD to avoid large modification.

Type of change

Please check all relevant options.

  • Improvement (non-breaking)
  • Bug fix (non-breaking)
  • New feature (non-breaking)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • This change requires a documentation update

Testing

The test of ACG is the same to that of APGD because ACG is very similar to APGD.

  • Whether the attack works correctly (tests/attacks/evasion/test_auto_conjugate_gradient.py)

Test Configuration:

  • OS: ubuntu 22.04
  • Python version: 3.9.13
  • ART version or commit number: 11,126
  • TensorFlow / Keras / PyTorch / MXNet version: 2.10.1 / 2.10.0 / 1.13.1 / 1.8.0.post0

Checklist

  • My code follows the style guidelines of this project
  • I have performed a self-review of my own code
  • I have commented my code
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: yamamura-k <yamayama23bb@gmail.com>
Signed-off-by: yamamura-k <yamayama23bb@gmail.com>
Signed-off-by: yamamura-k <yamayama23bb@gmail.com>
Signed-off-by: yamamura-k <yamayama23bb@gmail.com>
Signed-off-by: yamamura-k <yamayama23bb@gmail.com>
@beat-buesser beat-buesser self-requested a review February 15, 2023 20:07
@beat-buesser beat-buesser self-assigned this Feb 15, 2023
@beat-buesser
Copy link
Collaborator

Hi @yamamura-k Thank you very much for contributing your attack to ART and congratulations for your ICML paper! I think your attack will be very useful for many users of ART.

What kind of changes would be required to the ART estimators to support batches of size 2 or larger with the best attack performance? Maybe we can add the required functionality.

@codecov-commenter
Copy link

codecov-commenter commented Feb 15, 2023

Codecov Report

Merging #2028 (09281d0) into dev_1.14.0 (9b2891b) will increase coverage by 1.39%.
The diff coverage is 86.19%.

📣 This organization is not using Codecov’s GitHub App Integration. We recommend you install it so Codecov can continue to function properly for your repositories. Learn more

Impacted file tree graph

@@              Coverage Diff               @@
##           dev_1.14.0    #2028      +/-   ##
==============================================
+ Coverage       84.19%   85.59%   +1.39%     
==============================================
  Files             292      293       +1     
  Lines           25798    26158     +360     
  Branches         4665     4733      +68     
==============================================
+ Hits            21720    22389     +669     
+ Misses           2882     2554     -328     
- Partials         1196     1215      +19     
Impacted Files Coverage Δ
...attacks/evasion/auto_projected_gradient_descent.py 85.76% <86.00%> (-0.52%) ⬇️
art/attacks/evasion/auto_conjugate_gradient.py 86.20% <86.20%> (ø)
art/attacks/evasion/__init__.py 98.21% <100.00%> (+0.03%) ⬆️

... and 17 files with indirect coverage changes

@beat-buesser beat-buesser added the enhancement New feature or request label Feb 15, 2023
Comment on lines +632 to +633
# if self.loss_type not in self._predefined_losses:
# raise ValueError("The argument loss_type has to be either {}.".format(self._predefined_losses))

Check notice

Code scanning / CodeQL

Commented-out code

This comment appears to contain commented-out code.
@yamamura-k
Copy link
Contributor Author

What kind of changes would be required to the ART estimators to support batches of size 2 or larger with the best attack performance? Maybe we can add the required functionality.

@beat-buesser
Thank you for considering adding this feature to estimator.
In order to improve performance, we would like to calculate the condition for the step size update decision for each image. Therefore, we want to use the objective function value for each image in the batch, not the reduced (average or sum) value.

@beat-buesser
Copy link
Collaborator

Hi @yamamura-k
The method compute_loss of the classification estimators provides an option to define the type of reduction on the loss values. Currently it is set to reduction="mean" resulting in the average loss of all samples in the batch. This option can be set to reduction="none" to return the loss for each sample separately. Would this solve the issue?

@beat-buesser beat-buesser added this to the ART 1.14.0 milestone Feb 16, 2023
@yamamura-k
Copy link
Contributor Author

@beat-buesser Thank you for your comment. I think your suggestion can solve my problem. I misunderstood that the reduction affects the behavior of loss_gradient method. Thank you again for your suggestion.

I will fix my implementation later. Also, I can modify the implementation of auto_projected_gradient_descent.py at the same time. If this modification is helpful to you, I will modify the implementation of auto_projected_gradient_descent.py and send another pull request.

@yamamura-k
Copy link
Contributor Author

@beat-buesser Does the loss_gradient function work when the output of the loss is not scalar? I checked the implementation again, and I found that the output of the loss function class always returns the reduced value. So if the answer of the question is no, I think we have to change the implementation of loss_gradient or another related part.

@beat-buesser
Copy link
Collaborator

beat-buesser commented Feb 16, 2023

Hi @yamamura-k I think it would be amazing if you could upgrade your ACG and ART's APGD attacks to account for per sample step sizes! I think you are asking important questions, this is what I think:

  • reduction="none" would have to be applied in line 522, 523, and 555 to get per sample losses for the algorithms of ACG and APGD
  • Does the loss_gradient function work when the output of the loss is not scalar?

    • I understand this question affects line 485 of the ACG attack and similarly APGD. You are right, the method loss_gradient calculates the gradients of the average loss of a batch. But because of the chain rules for derivatives of sum and division (averaging of the loss) the gradients backpropagated from the average loss per batch should still result in per sample gradients independent of the other samples in the batch. In addition the per sample gradients are normalized depending on the selected norm after calling loss_gradient. Based on this I think we should not require any changes to loss_gradient for ACM and upgraded APGD.

What do you think?

@yamamura-k
Copy link
Contributor Author

@beat-buesser Thank you for your comments and suggestions.

Based on this I think we should not require any changes to loss_gradient for ACM and upgraded APGD.

I understand your opinion, and basically agree with you. However, the problem is how to satisfy the following conflicting requirements.

  • classifier._loss should return a vector of losses to get the loss values per samples.
  • classifier._loss should return the averaged losses to calculate the gradient.

I think reduction="none" cannot solve this problem directly because ART classifiers assume the output of loss is averaged over the batch (e.g. lines 843 to 856 in adversarial-robustness-toolbox/art/estimators/classification/pytorch.py). That is, the output of loss should be averaged in the definition of loss function class (like DifferenceLogitsRatioPyTorch) to calculate the gradient in the current implementation of loss_grad.

  • One possible solution is averaging the output of classifier._loss when the shape > (1, ) in loss_gradient function.
    The code below is the example of this solution. The similar modification should be applied to the other files.
    loss = self._loss(model_outputs[-1], labels_t) #line 843 in adversarial-robustness-toolbox/art/estimators/classification/pytorch.py
    # My suggestion
    if len(loss.shape) == 1 and loss.shape[0] > 1:
            loss = loss.mean() # reduce the loss
    elif len(loss.shape) > 1:
            raise ValueError
    # !My suggestion
    # Clean gradients
    self._model.zero_grad()

    # Compute gradients
    if self._use_amp:  # pragma: no cover
        from apex import amp  # pylint: disable=E0611

        with amp.scale_loss(loss, self._optimizer) as scaled_loss:
            scaled_loss.backward()

    else:
        loss.backward()
  • Another possible solution is to pass the keyword argument reduction to the classifier._loss and reduce the loss value in classifier._loss according to reduction keyword. (currently, user-specified reduction is applied in compute_loss but it seems not to work because the classifier._loss returns the mean of loss value.)

These are what I'm thinking about. Would you tell me your opinion?

@beat-buesser
Copy link
Collaborator

beat-buesser commented Feb 17, 2023

Hi @yamamura-k

I think your considerations are correct and important.

ART expects classifier._loss to be the reduced loss (average, sum, etc.) across the batch as a single value.

Your last proposal of passing the keyword argument reduction to the classifier._loss is a good idea and I think we are doing this already in classifier.compute_loss in lines

prev_reduction = self._loss.reduction
# Return individual loss values
self._loss.reduction = reduction
loss = self._loss(model_outputs[-1], labels_t)
self._loss.reduction = prev_reduction

where we save the original reduction option to prev_reduction and set the reduction option defined in the keyword argument with self._loss.reduction = reduction. In the next lines we revert self._loss.reduction to the original reduction option.

I have run a short test based on one of ART's example scripts which trains a PyTorchClassifier and runs/prints compute_loss with both options mean and none:

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np

from art.attacks.evasion import FastGradientMethod
from art.estimators.classification import PyTorchClassifier
from art.utils import load_mnist


# Step 0: Define the neural network model, return logits instead of activation in forward method


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv_1 = nn.Conv2d(in_channels=1, out_channels=4, kernel_size=5, stride=1)
        self.conv_2 = nn.Conv2d(in_channels=4, out_channels=10, kernel_size=5, stride=1)
        self.fc_1 = nn.Linear(in_features=4 * 4 * 10, out_features=100)
        self.fc_2 = nn.Linear(in_features=100, out_features=10)

    def forward(self, x):
        x = F.relu(self.conv_1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv_2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4 * 4 * 10)
        x = F.relu(self.fc_1(x))
        x = self.fc_2(x)
        return x


# Step 1: Load the MNIST dataset

(x_train, y_train), (x_test, y_test), min_pixel_value, max_pixel_value = load_mnist()

# Step 1a: Swap axes to PyTorch's NCHW format

x_train = np.transpose(x_train, (0, 3, 1, 2)).astype(np.float32)
x_test = np.transpose(x_test, (0, 3, 1, 2)).astype(np.float32)

# Step 2: Create the model

model = Net()

# Step 2a: Define the loss function and the optimizer

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)

# Step 3: Create the ART classifier

classifier = PyTorchClassifier(
    model=model,
    clip_values=(min_pixel_value, max_pixel_value),
    loss=criterion,
    optimizer=optimizer,
    input_shape=(1, 28, 28),
    nb_classes=10,
)

# Step 4: Train the ART classifier

classifier.fit(x_train, y_train, batch_size=64, nb_epochs=3)

print("mean")
print(classifier.compute_loss(x=x_test[:10], y=y_test[:10], reduction="mean"))

print("none")
print(classifier.compute_loss(x=x_test[:10], y=y_test[:10], reduction="none"))

This script should print something like

mean
0.011778888
none
[ 7.0333235e-06  1.9073468e-06  2.7418098e-06  1.0013530e-05
  1.1394425e-02 -0.0000000e+00  2.2188823e-03  2.7656173e-05
  1.0411299e-01  1.3232144e-05]

Would this provide the required loss values for ACG attack?

@yamamura-k
Copy link
Contributor Author

@beat-buesser Thank you for your explanation about the functionality of ART. I think the explained functionality satisfies my requirements. Then I will try my second proposal and push the commits later.

Signed-off-by: yamamura-k <yamayama23bb@gmail.com>
…able the image-wise stepsize update

Signed-off-by: yamamura-k <yamayama23bb@gmail.com>
@yamamura-k
Copy link
Contributor Author

yamamura-k commented Feb 17, 2023

@beat-buesser I modified the implementation of ACG and APGD to enable the image-wise stepsize updates in this pull request. Thank you for your patience and valuable comments and suggestions.
I confirmed that the current implementation showed the similar attack performance to the original implementation.

@beat-buesser
Copy link
Collaborator

Hi @yamamura-k I think the proposed changes above should fix the style checks. What do you think?

Signed-off-by: yamamura-k <yamayama23bb@gmail.com>
@yamamura-k
Copy link
Contributor Author

Hi @beat-buesser, the codes changed by your suggestion fix the style checks, and no new warnings are raised. Thank you so much!

@yamamura-k yamamura-k closed this Feb 28, 2023
@yamamura-k yamamura-k reopened this Feb 28, 2023
@beat-buesser
Copy link
Collaborator

Hi @yamamura-k Thank you very much for contributing your attack Auto Conjugate Gradient Attack (ICML 2022) to ART! It will be part of ART 1.14!

@beat-buesser beat-buesser merged commit 0a0a701 into Trusted-AI:dev_1.14.0 Mar 11, 2023
@yamamura-k
Copy link
Contributor Author

@beat-buesser Thank you very much for your patience in working with me to modify the code!

sheatsley added a commit to sheatsley/adversarial-robustness-toolbox that referenced this pull request Mar 24, 2023
…I#2028 because pytorch estimators do not recognize the class created in AGPD and thus do not correctly handle the labels anymore
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants