Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

adding adversarial weight perturbation protocol #2224

Merged
merged 6 commits into from
Sep 8, 2023

Conversation

Zaid-Hameed
Copy link
Collaborator

Description

AWP is an important adversarial training approach because it provides better robustness against adversarial attacks and mitigates robust overfitting. AWP has been proposed in paper "Adversarial Weight Perturbation Helps
Robust Generalization".

Paper link: https://proceedings.neurips.cc/paper/2020/file/1ef91c212e30e14bf125e9374262401f-Paper.pdf

It is also a base component of more advanced adversarial training approaches.

Fixes #2164

Type of change

Please check all relevant options.

  • Improvement (non-breaking)
  • Bug fix (non-breaking)
  • New feature (non-breaking)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • This change requires a documentation update

Testing

Please describe the tests that you ran to verify your changes. Consider listing any relevant details of your test configuration.

  • Adversarial weight perturbation based training implementation produces results similar to original implementation
  • All functions in implemented code work as expected

Test Configuration:

  • OS: Red Hat Enterprise Linux 8.7 (Ootpa)
  • Python version: 3.9.12
  • ART version or commit number
  • TensorFlow / Keras / PyTorch / MXNet version: PyTorch 1.13.1+cu117

Checklist

  • My code follows the style guidelines of this project
  • I have performed a self-review of my own code
  • I have commented my code
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Muhammad Zaid Hameed <Zaid.Hameed@ibm.com>
@codecov-commenter
Copy link

codecov-commenter commented Jul 20, 2023

Codecov Report

Merging #2224 (0a78cdb) into dev_1.16.0 (19259d7) will increase coverage by 0.09%.
The diff coverage is 89.34%.

❗ Your organization needs to install the Codecov GitHub app to enable full functionality.

Impacted file tree graph

@@              Coverage Diff               @@
##           dev_1.16.0    #2224      +/-   ##
==============================================
+ Coverage       84.76%   84.85%   +0.09%     
==============================================
  Files             313      315       +2     
  Lines           27810    28054     +244     
  Branches         5086     5123      +37     
==============================================
+ Hits            23572    23805     +233     
+ Misses           2948     2941       -7     
- Partials         1290     1308      +18     
Files Changed Coverage Δ
...efences/trainer/adversarial_trainer_awp_pytorch.py 88.07% <88.07%> (ø)
art/defences/trainer/__init__.py 100.00% <100.00%> (ø)
art/defences/trainer/adversarial_trainer_awp.py 100.00% <100.00%> (ø)

... and 6 files with indirect coverage changes

📢 Have feedback on the report? Share it here.

Comment on lines +104 to +110
def fit_generator( # pylint: disable=W0221
self,
generator: DataGenerator,
validation_data: Optional[Tuple[np.ndarray, np.ndarray]] = None,
nb_epochs: int = 20,
**kwargs
):

Check notice

Code scanning / CodeQL

Mismatch between signature and use of an overridden method Note

Overridden method signature does not match
call
, where it is passed too many arguments. Overriding method
method AdversarialTrainerAWPPyTorch.fit_generator
matches the call.
Overridden method signature does not match
call
, where it is passed an argument named 'scheduler'. Overriding method
method AdversarialTrainerAWPPyTorch.fit_generator
matches the call.
Signed-off-by: Muhammad Zaid Hameed <Zaid.Hameed@ibm.com>
@beat-buesser beat-buesser self-requested a review July 22, 2023 09:53
@beat-buesser beat-buesser self-assigned this Jul 22, 2023
@beat-buesser beat-buesser added the enhancement New feature or request label Jul 22, 2023
@beat-buesser beat-buesser added this to the ART 1.16.0 milestone Jul 22, 2023
@beat-buesser beat-buesser linked an issue Jul 22, 2023 that may be closed by this pull request
Copy link
Collaborator

@beat-buesser beat-buesser left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @Zaid-Hameed Thank you very much for your pull request! I have added a few minor comments on using properties to avoid pylint warnings. What do you think?

Have you tested your code on GPUs?

from art.utils import CLASSIFIER_LOSS_GRADIENTS_TYPE


class AdversarialTrainerAWP(Trainer, abc.ABC):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is inheriting from abc.ABC required here? Trainer is already inheriting from abc.ABC.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done by removing abc.ABC.

# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
"""
This module implements adversarial training with AWP protocol.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's introduce the abbreviation AWP somewhere.

Suggested change
This module implements adversarial training with AWP protocol.
This module implements adversarial training with Adversarial Weight Perturbation (AWP) protocol.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.


class AdversarialTrainerAWPPyTorch(AdversarialTrainerAWP):
"""
Class performing adversarial training following AWP protocol.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
Class performing adversarial training following AWP protocol.
Class performing adversarial training following Adversarial Weight Perturbation (AWP) protocol.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

import torch

logger = logging.getLogger(__name__)
EPS = 1e-8
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the definition of EPS?

Copy link
Collaborator Author

@Zaid-Hameed Zaid-Hameed Aug 18, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added definition.

Comment on lines +75 to +82
self._classifier: PyTorchClassifier
self._proxy_classifier: PyTorchClassifier
self._attack: EvasionAttack
self._mode: str
self.gamma: float
self._beta: float
self._warmup: int
self._apply_wp: bool
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are these type assignments needed?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, for static type checking by mypy.


params_dict = OrderedDict() # type: ignore
list_params = []
for name, param in p_classifier._model.state_dict().items(): # pylint: disable=W0212
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
for name, param in p_classifier._model.state_dict().items(): # pylint: disable=W0212
for name, param in p_classifier.model.state_dict().items():

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

self, p_classifier: PyTorchClassifier, list_keys: List[str], w_perturb: Dict[str, "torch.Tensor"], op: str
) -> None:
"""

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add a description of the method.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added description.

:param w_perturb: dictionary containing model parameters' names as keys and model parameters as values
:param op: controls whether weight perturbation will be added or subtracted from model parameters
"""

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

else:
raise ValueError("Incorrect op provided for weight perturbation. 'op' must be among 'add' and 'subtract'.")
with torch.no_grad():
for name, param in p_classifier._model.named_parameters(): # pylint: disable=W0212
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
for name, param in p_classifier._model.named_parameters(): # pylint: disable=W0212
for name, param in p_classifier.model.named_parameters():

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

)


# Build a Keras image augmentation object and wrap it in ART
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it Keras?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed to correct description.

Muhammad Zaid Hameed and others added 4 commits August 15, 2023 11:48
Signed-off-by: Muhammad Zaid Hameed <Zaid.Hameed@ibm.com>
Signed-off-by: Muhammad Zaid Hameed <Zaid.Hameed@ibm.com>
Signed-off-by: Muhammad Zaid Hameed <Zaid.Hameed@ibm.com>
@beat-buesser beat-buesser merged commit 90bf04b into Trusted-AI:dev_1.16.0 Sep 8, 2023
41 of 47 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Adversarial Weight Perturbation based adversarial Training
3 participants