feat(torch): add ADAMP variant of adam in RANGER (2006.08217)
fantes authored and mergify[bot] committed Jun 24, 2021
1 parent ce07602 commit e26ed77
Showing 6 changed files with 116 additions and 8 deletions.
1 change: 1 addition & 0 deletions docs/api.md
@@ -701,6 +701,7 @@ clip_norm | real | yes | 100.0 | gradients with euclid
rectified | bool | yes | false | rectified momentum variance ie https://arxiv.org/abs/1908.03265 valid for ADAM[W] and AMSGRAD[W]
adabelief | bool | yes | false | adabelief mod for ADAM https://arxiv.org/abs/2010.07468
gradient_centralization | bool | yes | false | centralized gradient mod for ADAM ie https://arxiv.org/abs/2004.01461v2
adamp | bool | yes | false | enable ADAMP version https://arxiv.org/abs/2006.08217
test_interval | int | yes | N/A | Number of iterations between testing phases
test_initialization | bool | true | N/A | Whether to start training by testing the network
lr_policy | string | yes | N/A | learning rate policy ("step", "inv", "fixed", "sgdr", ...)
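For orientation, the new flag is just another boolean in the RANGER solver object of a training call; a minimal fragment could look like the sketch below (illustrative values only, assembled from the parameters documented in the table above, not output of the API):

  "solver": {
    "solver_type": "RANGER",
    "base_lr": 0.0001,
    "iterations": 1000,
    "test_interval": 500,
    "lookahead": true,
    "rectified": true,
    "adamp": true
  }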
104 changes: 97 additions & 7 deletions src/backends/torch/optim/ranger.cc
@@ -25,6 +25,10 @@
#include "./ranger.h"
#include <torch/csrc/autograd/variable.h>
#include <torch/nn/module.h>
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#include <torch/nn/functional.h>
#pragma GCC diagnostic pop
#include <torch/serialize/archive.h>
#include <torch/utils.h>

@@ -52,8 +56,8 @@ namespace dd
&& (lhs.lookahead() == rhs.lookahead())
&& (lhs.adabelief() == rhs.adabelief())
&& (lhs.gradient_centralization() == rhs.gradient_centralization())
&& (lhs.lsteps() == rhs.lsteps()) && (lhs.lalpha() == rhs.lalpha())
&& (lhs.swa() == rhs.swa());
&& (lhs.adamp() == rhs.adamp()) && (lhs.lsteps() == rhs.lsteps())
&& (lhs.lalpha() == rhs.lalpha()) && (lhs.swa() == rhs.swa());
}

void RangerOptions::serialize(torch::serialize::OutputArchive &archive) const
@@ -67,6 +71,7 @@ namespace dd
_TORCH_OPTIM_SERIALIZE_TORCH_ARG(lookahead);
_TORCH_OPTIM_SERIALIZE_TORCH_ARG(adabelief);
_TORCH_OPTIM_SERIALIZE_TORCH_ARG(gradient_centralization);
_TORCH_OPTIM_SERIALIZE_TORCH_ARG(adamp);
_TORCH_OPTIM_SERIALIZE_TORCH_ARG(lsteps);
_TORCH_OPTIM_SERIALIZE_TORCH_ARG(lalpha);
_TORCH_OPTIM_SERIALIZE_TORCH_ARG(swa);
Expand All @@ -83,6 +88,7 @@ namespace dd
_TORCH_OPTIM_DESERIALIZE_TORCH_ARG(bool, lookahead);
_TORCH_OPTIM_DESERIALIZE_TORCH_ARG(bool, adabelief);
_TORCH_OPTIM_DESERIALIZE_TORCH_ARG(bool, gradient_centralization);
_TORCH_OPTIM_DESERIALIZE_TORCH_ARG(bool, adamp);
_TORCH_OPTIM_DESERIALIZE_TORCH_ARG(int, lsteps);
_TORCH_OPTIM_DESERIALIZE_TORCH_ARG(double, lalpha);
_TORCH_OPTIM_DESERIALIZE_TORCH_ARG(bool, swa);
@@ -116,6 +122,48 @@ namespace dd
_TORCH_OPTIM_DESERIALIZE_TORCH_ARG(torch::Tensor, swa_buffer);
}

float Ranger::projection(torch::Tensor p, torch::Tensor grad,
torch::Tensor perturb, float eps, float delta,
float wd)
{
std::vector<long int> expand_size(p.sizes().size(), 1);
expand_size[0] = -1;

std::vector<std::function<torch::Tensor(torch::Tensor)>> view_funcs{
[](torch::Tensor x) {
return x.view({ x.size(0), -1L });
},
[](torch::Tensor x) {
return x.view({ 1L, -1L });
}
};

for (auto view_func : view_funcs)
{
torch::Tensor p_data_view = view_func(p.data());
torch::Tensor cosine_similarity
= torch::nn::functional::cosine_similarity(
view_func(grad), p_data_view,
torch::nn::functional::CosineSimilarityFuncOptions()
.dim(1)
.eps(eps))
.abs_();

if (cosine_similarity.max().item<double>()
< delta / std::sqrt(p_data_view.size(1)))
{
torch::Tensor p_n
= p.data()
/ p_data_view.norm({ 1 }).view(expand_size).add_(eps);
perturb
-= p_n * view_func(p_n * perturb).sum({ 1 }).view(expand_size);
return wd;
}
}

return 1.0;
}

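As a reading aid for the projection function above: our transcription of the AdamP rule (arXiv:2006.08217) that this code mirrors, in notation of ours, is

\[
\max_i \left|\cos\theta\!\left(\nabla_{W_i} f,\, W_i\right)\right| \;<\; \frac{\delta}{\sqrt{\dim W_i}}
\;\Longrightarrow\;
\Delta \leftarrow \Delta - \hat W\,\langle \hat W, \Delta\rangle,
\qquad
\hat W = \frac{W}{\lVert W\rVert + \epsilon},
\]

where \(\Delta\) is the Adam perturbation (exp_avg / denom), the rows \(W_i\) are the two views tried by view_funcs (per output channel, then the whole tensor flattened), and \(\delta = 0.1\) by default. When the criterion holds, the function projects perturb in place and returns wd (0.1 by default) so that the decoupled weight decay gets damped; otherwise it returns 1 and leaves perturb untouched.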
torch::Tensor Ranger::step(LossClosure closure)
{
torch::NoGradGuard no_grad;
@@ -175,7 +223,8 @@ namespace dd
auto bias_correction1 = 1.0 - std::pow(beta1, state.step());
auto bias_correction2 = 1.0 - std::pow(beta2, state.step());

if (options.weight_decay() != 0) // weight decay not decoupled !!
if (options.weight_decay() != 0
&& !options.adamp()) // weight decay not decoupled !!
grad = grad.add(p, options.weight_decay());

if (options.gradient_centralization())
@@ -201,7 +250,19 @@ namespace dd
.add_(options.eps());

auto step_size = options.lr() / bias_correction1;
p.addcdiv_(exp_avg, denom, -step_size);
auto perturb = exp_avg / denom;
// below adamp
if (options.adamp())
{
double wd_ratio = 1;
if (p.sizes().size() > 1)
wd_ratio = projection(p, grad, perturb, options.eps());
if (options.weight_decay() != 0)
p.data().mul_(1.0
- options.lr() * options.weight_decay()
* wd_ratio);
}
p.add_(perturb, -step_size);
}
else
{
@@ -223,12 +284,41 @@ namespace dd
.add_(options.eps());
else
denom = exp_avg_sq.sqrt().add_(options.eps());
p.addcdiv_(exp_avg, denom, -step_size * options.lr());
auto perturb = exp_avg / denom;
step_size *= options.lr();
// below adamp
if (options.adamp())
{
double wd_ratio = 1;
if (p.sizes().size() > 1)
wd_ratio
= projection(p, grad, perturb, options.eps());
if (options.weight_decay() != 0)
p.data().mul_(1.0
- options.lr() * options.weight_decay()
* wd_ratio);
}

p.add_(perturb, -step_size);
}
else
{
step_size = 1.0 / bias_correction1;
p.add_(exp_avg, -step_size * options.lr());
step_size = options.lr() / bias_correction1;
auto perturb = exp_avg;
// below adamp
if (options.adamp())
{
double wd_ratio = 1;
if (p.sizes().size() > 1)
wd_ratio
= projection(p, grad, perturb, options.eps());
if (options.weight_decay() != 0)
p.data().mul_(1.0
- options.lr() * options.weight_decay()
* wd_ratio);
}

p.add_(perturb, -step_size);
}
}
if (state.step() % options.lsteps() == 0 && options.lookahead())
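As a standalone illustration of what the new projection path does, here is a toy program of ours against libtorch (not part of the commit; the weight shape, the orthogonalization step, and the momentum noise are ours, chosen so the criterion fires the way it would for weights feeding a normalization layer):

  #include <torch/torch.h>
  #include <cmath>
  #include <iostream>

  int main()
  {
    torch::manual_seed(0);
    const double eps = 1e-8, delta = 0.1, wd = 0.1;

    // a conv-like weight and its gradient
    torch::Tensor p = torch::randn({ 8, 3, 3, 3 });
    torch::Tensor grad = torch::randn_like(p);

    auto view = [](const torch::Tensor &x) { return x.view({ x.size(0), -1 }); };

    // make the gradient orthogonal to each filter, as it is (up to noise) for
    // scale-invariant weights sitting in front of a normalization layer
    torch::Tensor coef
        = view(p * grad).sum({ 1 }, true) / view(p * p).sum({ 1 }, true);
    grad = grad - p * coef.view({ p.size(0), 1, 1, 1 });

    // an Adam-style perturbation: mostly the current gradient plus stale
    // momentum, so it still has a radial (norm-growing) component
    torch::Tensor perturb = grad + 0.1 * torch::randn_like(p);

    torch::Tensor cos
        = torch::nn::functional::cosine_similarity(
              view(grad), view(p),
              torch::nn::functional::CosineSimilarityFuncOptions().dim(1).eps(eps))
              .abs();

    std::cout << "<p, perturb> before: "
              << view(p * perturb).sum({ 1 }).abs().max().item<double>()
              << std::endl;

    double wd_ratio = 1.0;
    if (cos.max().item<double>() < delta / std::sqrt((double)view(p).size(1)))
      {
        // project the perturbation onto the plane orthogonal to each filter
        torch::Tensor p_norm = view(p).pow(2).sum({ 1 }, true).sqrt().add(eps);
        torch::Tensor p_n = p / p_norm.view({ p.size(0), 1, 1, 1 });
        perturb = perturb
                  - p_n
                        * view(p_n * perturb)
                              .sum({ 1 }, true)
                              .view({ p.size(0), 1, 1, 1 });
        wd_ratio = wd; // decoupled weight decay gets scaled by this
      }

    std::cout << "<p, perturb> after: "
              << view(p * perturb).sum({ 1 }).abs().max().item<double>()
              << std::endl;
    std::cout << "wd_ratio: " << wd_ratio << std::endl;
    return 0;
  }

The radial component of the perturbation (the inner product with each filter) drops to roughly zero after the projection, which is what keeps the effective step size from decaying on scale-invariant weights. In the diff itself the same four lines are repeated in the plain Adam, rectified, and non-rectified branches of step(); only the source of perturb and step_size differs.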
5 changes: 5 additions & 0 deletions src/backends/torch/optim/ranger.h
@@ -58,6 +58,7 @@ namespace dd
TORCH_ARG(bool, lookahead) = true;
TORCH_ARG(bool, adabelief) = false;
TORCH_ARG(bool, gradient_centralization) = false;
TORCH_ARG(bool, adamp) = false;
TORCH_ARG(int, lsteps) = 6;
TORCH_ARG(double, lalpha) = 0.5;
TORCH_ARG(bool, swa) = false;
@@ -125,6 +126,10 @@ namespace dd

void swap_swa_sgd();

float projection(torch::Tensor p, torch::Tensor grad,
torch::Tensor perturb, float eps, float delta = 0.1,
float wd = 0.1); // delta and wd_ratio from paper

private:
template <typename Self, typename Archive>
static void serialize(Self &self, Archive &archive)
10 changes: 10 additions & 0 deletions src/backends/torch/torchsolver.cc
@@ -62,6 +62,8 @@ namespace dd
_adabelief = ad_solver.get("adabelief").get<bool>();
if (ad_solver.has("gradient_centralization"))
_gc = ad_solver.get("gradient_centralization").get<bool>();
if (ad_solver.has("adamp"))
_adamp = ad_solver.get("adamp").get<bool>();
if (ad_solver.has("lookahead_steps"))
_lsteps = ad_solver.get("lookahead_steps").get<int>();
if (ad_solver.has("lookahead_alpha"))
@@ -126,6 +128,7 @@ namespace dd
.lookahead(_lookahead)
.adabelief(_adabelief)
.gradient_centralization(_gc)
.adamp(_adamp)
.lsteps(_lsteps)
.lalpha(_lalpha)));
this->_logger->info("base_lr: {}", _base_lr);
@@ -136,6 +139,12 @@ namespace dd
this->_logger->info("lookahead: {}", _lookahead);
this->_logger->info("adabelief: {}", _adabelief);
this->_logger->info("gradient_centralization: {}", _gc);
this->_logger->info("adamp: {}", _adamp);
if (_adamp && _adabelief)
this->_logger->warn(
"both adabelief and adamp seletected, preliminary tests show "
"that adamp works better w/o adabelief, please double check "
"your parameters");
if (_lookahead)
{
this->_logger->info("lookahead steps: {}", _lsteps);
@@ -395,6 +404,7 @@ namespace dd
options.lookahead(_lookahead);
options.adabelief(_adabelief);
options.gradient_centralization(_gc);
options.adamp(_adamp);
options.lsteps(_lsteps);
options.lalpha(_lalpha);
}
1 change: 1 addition & 0 deletions src/backends/torch/torchsolver.h
@@ -146,6 +146,7 @@ namespace dd
bool _lookahead = true; /**< for RANGER : use hinton's lookahead */
bool _adabelief = false; /**< for RANGER : use ADABELIEF version */
bool _gc = false; /**< for RANGER : use gradient centralization */
bool _adamp = false; /**< for RANGER : ADAMP variant */
int _lsteps
= 5; /**< for RANGER, if lookahead: number of lookahead steps */
double _lalpha = 0.5; /**< for RANGER, if lookahead: weight of lookahead */
3 changes: 2 additions & 1 deletion tests/ut-torchapi.cc
@@ -2667,7 +2667,8 @@ TEST(torchapi, service_train_ranger)
+ iterations_resnet50 + ",\"base_lr\":" + torch_lr
+ ",\"iter_size\":4,\"solver_type\":\"RANGER\","
"\"lookahead\":true,\"rectified\":true,\"adabelief\":true,"
"\"gradient_centralization\":true,\"clip\":false,\"test_interval\":"
"\"gradient_centralization\":true,\"adamp\":true,\"clip\":false,"
"\"test_interval\":"
+ iterations_resnet50
+ "},\"net\":{\"batch_size\":4},\"nclasses\":2,\"resume\":false},"
"\"input\":{\"seed\":12345,\"db\":true,\"shuffle\":true,\"test_"
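To exercise just this path after building the unit tests, GoogleTest's name filter can be used; the binary path below is assumed from the source file name and may differ in your build tree:

  ./tests/ut_torchapi --gtest_filter=torchapi.service_train_ranger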
