Bug fix, formatting code and removing old TODOs in future/torch #1140

Open
wants to merge 8 commits into master

67 changes: 34 additions & 33 deletions cleverhans/future/torch/attacks/noise.py
@@ -6,36 +6,37 @@
import torch


def noise(x, eps=0.3, order=np.inf, clip_min=None, clip_max=None):
"""
A weak attack that just picks a random point in the attacker's action
space. When combined with an attack bundling function, this can be used to
implement random search.

References:
https://arxiv.org/abs/1802.00420 recommends random search to help identify
gradient masking

https://openreview.net/forum?id=H1g0piA9tQ recommends using noise as part
of an attack building recipe combining many different optimizers to
yield a strong optimizer.

Args:
:param x: the input tensor
:param eps: (optional float) maximum distortion of adversarial example
compared to original input.
:param norm: (optional) Order of the norm.
:param clip_min: (optional float) Minimum input component value
:param clip_max: (optional float) Maximum input component value
"""

if order != np.inf: raise NotImplementedError(norm)

eta = torch.FloatTensor(*x.shape).to(x.device).uniform_(-eps, eps)
adv_x = x + eta

if clip_min is not None or clip_max is not None:
assert clip_min is not None and clip_max is not None
adv_x = torch.clamp(adv_x, min=clip_min, max=clip_max)

return adv_x
def noise(x, eps=0.3, norm=np.inf, clip_min=None, clip_max=None):
"""
A weak attack that just picks a random point in the attacker's action
space. When combined with an attack bundling function, this can be used to
implement random search.

References:
https://arxiv.org/abs/1802.00420 recommends random search to help identify
gradient masking

https://openreview.net/forum?id=H1g0piA9tQ recommends using noise as part
of an attack building recipe combining many different optimizers to
yield a strong optimizer.

Args:
:param x: the input tensor
:param eps: (optional float) maximum distortion of adversarial example
compared to original input.
:param norm: (optional) Order of the norm.
:param clip_min: (optional float) Minimum input component value
:param clip_max: (optional float) Maximum input component value
"""

if norm != np.inf:
raise ValueError("Norm order must be np.inf, got {} instead.".format(norm))

eta = torch.FloatTensor(*x.shape).to(x.device).uniform_(-eps, eps)
adv_x = x + eta

if clip_min is not None or clip_max is not None:
assert clip_min is not None and clip_max is not None
adv_x = torch.clamp(adv_x, min=clip_min, max=clip_max)

return adv_x
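
For reference, a minimal usage sketch of the updated function, assuming the module is importable from the file path shown above; the tensor shape, eps value, and clipping range are illustrative, not taken from the PR.

import numpy as np
import torch

from cleverhans.future.torch.attacks.noise import noise

# Illustrative batch of images with pixel values in [0, 1].
x = torch.rand(8, 3, 32, 32)

# Random L-inf perturbation of radius 0.1, clipped back to the valid pixel range.
adv_x = noise(x, eps=0.1, norm=np.inf, clip_min=0., clip_max=1.)

assert adv_x.shape == x.shape
assert (adv_x - x).abs().max() <= 0.1 + 1e-6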
@@ -113,7 +113,6 @@ def projected_gradient_descent(model_fn, x, eps, eps_iter, nb_iter, norm,

asserts.append(eps_iter <= eps)
if norm == np.inf and clip_min is not None:
# TODO necessary to cast clip_min and clip_max to x.dtype?
asserts.append(eps + clip_min <= clip_max)

if sanity_checks:
15 changes: 2 additions & 13 deletions cleverhans/future/torch/utils.py
@@ -36,9 +36,10 @@ def clip_eta(eta, norm, eps):
torch.tensor(1., dtype=eta.dtype, device=eta.device),
eps / norm
)
eta *= factor
eta = eta * factor
return eta
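
As a side note on the change from eta *= factor to eta = eta * factor: a minimal sketch of the failure mode the out-of-place form avoids, assuming eta can arrive as a leaf tensor with requires_grad=True (an assumption, not stated in the diff).

import torch

eta = torch.zeros(4, requires_grad=True)  # leaf tensor, as an attack's perturbation often is
factor = torch.tensor(0.5)

try:
    eta *= factor      # PyTorch rejects in-place updates on a leaf that requires grad
except RuntimeError as err:
    print(err)

eta = eta * factor     # the out-of-place product builds a new tensor and keeps autograd intact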


def get_or_guess_labels(model, x, **kwargs):
"""
Get the label to use in generating an adversarial example for x.
@@ -84,7 +85,6 @@ def optimize_linear(grad, eps, norm=np.inf):
# Take sign of gradient
optimal_perturbation = torch.sign(grad)
elif norm == 1:
abs_grad = torch.abs(grad)
sign = torch.sign(grad)
red_ind = list(range(1, len(grad.size())))
abs_grad = torch.abs(grad)
@@ -97,23 +97,12 @@
for red_scalar in red_ind:
num_ties = torch.sum(num_ties, red_scalar, keepdim=True)
optimal_perturbation = sign * max_mask / num_ties
# TODO integrate below to a test file
# check that the optimal perturbations have been correctly computed
opt_pert_norm = optimal_perturbation.abs().sum(dim=red_ind)
assert torch.all(opt_pert_norm == torch.ones_like(opt_pert_norm))
elif norm == 2:
square = torch.max(
avoid_zero_div,
torch.sum(grad ** 2, red_ind, keepdim=True)
)
optimal_perturbation = grad / torch.sqrt(square)
# TODO integrate below to a test file
# check that the optimal perturbations have been correctly computed
opt_pert_norm = optimal_perturbation.pow(2).sum(dim=red_ind, keepdim=True).sqrt()
one_mask = (
(square <= avoid_zero_div).to(torch.float) * opt_pert_norm +
(square > avoid_zero_div).to(torch.float))
assert torch.allclose(opt_pert_norm, one_mask, rtol=1e-05, atol=1e-08)
else:
raise NotImplementedError("Only L-inf, L1 and L2 norms are "
"currently implemented.")
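
The deleted assertions were flagged with a TODO to move them into a test file; below is a hedged sketch of such a test. The test name, gradient shape, and tolerances are assumptions, and the import path simply mirrors cleverhans/future/torch/utils.py.

import torch

from cleverhans.future.torch.utils import optimize_linear


def test_optimize_linear_norms():
    """For eps=1, the L1 and L2 solutions should have unit norm per example."""
    grad = torch.randn(8, 3, 32, 32)      # illustrative gradient batch
    red_ind = list(range(1, grad.dim()))  # reduce over all non-batch dimensions

    pert_l1 = optimize_linear(grad, eps=1., norm=1)
    l1 = pert_l1.abs().sum(dim=red_ind)
    assert torch.allclose(l1, torch.ones_like(l1))

    pert_l2 = optimize_linear(grad, eps=1., norm=2)
    l2 = pert_l2.pow(2).sum(dim=red_ind).sqrt()
    assert torch.allclose(l2, torch.ones_like(l2), rtol=1e-05, atol=1e-08)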
2 changes: 2 additions & 0 deletions tutorials/future/torch/cifar10_tutorial.py
@@ -62,6 +62,8 @@ def main(_):
if FLAGS.adv_train:
# Replace clean example with adversarial example for adversarial training
x = projected_gradient_descent(net, x, FLAGS.eps, 0.01, 40, np.inf)
# Stop backward from entering the graph that created the adv example
x = x.clone().detach()
optimizer.zero_grad()
loss = loss_fn(net(x), y)
loss.backward()
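
For context, a hedged sketch of the complete adversarial-training step once the detach is in place; net, loss_fn, optimizer, x, y and the FLAGS values come from the surrounding tutorial, and the eps and step sizes here are illustrative.

# Inside the training loop, when FLAGS.adv_train is set:
x = projected_gradient_descent(net, x, 0.3, 0.01, 40, np.inf)  # craft the adversarial batch
x = x.clone().detach()        # drop the graph built during the 40 attack iterations
optimizer.zero_grad()
loss = loss_fn(net(x), y)     # gradients now flow only through this forward pass
loss.backward()
optimizer.step()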