feat(torch): add radam optimizer
fantes authored and mergify[bot] committed Jan 12, 2023
1 parent 7340a48 commit 5bba045
Showing 5 changed files with 425 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/CMakeLists.txt
@@ -110,6 +110,7 @@ if (USE_TORCH)
backends/torch/torchutils.cc
backends/torch/optim/ranger.cc
backends/torch/optim/madgrad.cc
backends/torch/optim/radam.cc
backends/torch/torchdataaug.cc
)
endif()
211 changes: 211 additions & 0 deletions src/backends/torch/optim/radam.cc
@@ -0,0 +1,211 @@
/**
* DeepDetect
* Copyright (c) 2019-2023 Jolibrain
* Author: Guillaume Infantes <guillaume.infantes@jolibrain.com>
*
* This file is part of deepdetect.
*
* deepdetect is free software: you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* deepdetect is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public License
* along with deepdetect. If not, see <http://www.gnu.org/licenses/>.
*/

/* This is largely inspired by and adapted from the Adam torch/C++
 * implementation, i.e. pytorch/torch/csrc/api/src/optim/adam.cpp */

#include "./radam.h"
#include "mllibstrategy.h"

#include <torch/csrc/autograd/variable.h>
#include <torch/nn/module.h>
#include <torch/serialize/archive.h>
#include <torch/utils.h>

#include <ATen/ATen.h>
#include <c10/util/irange.h>

#include <cmath>
#include <functional>

namespace dd
{
RAdamOptions::RAdamOptions(double lr) : lr_(lr)
{
}

bool operator==(const RAdamOptions &lhs, const RAdamOptions &rhs)
{
return (lhs.lr() == rhs.lr())
&& (std::get<0>(lhs.betas()) == std::get<0>(rhs.betas()))
&& (std::get<1>(lhs.betas()) == std::get<1>(rhs.betas()))
&& (lhs.eps() == rhs.eps())
&& (lhs.weight_decay() == rhs.weight_decay());
}

void RAdamOptions::serialize(torch::serialize::OutputArchive &archive) const
{
_TORCH_OPTIM_SERIALIZE_TORCH_ARG(lr);
_TORCH_OPTIM_SERIALIZE_TORCH_ARG(betas);
_TORCH_OPTIM_SERIALIZE_TORCH_ARG(eps);
_TORCH_OPTIM_SERIALIZE_TORCH_ARG(weight_decay);
}

void RAdamOptions::serialize(torch::serialize::InputArchive &archive)
{
_TORCH_OPTIM_DESERIALIZE_TORCH_ARG(double, lr);
_TORCH_OPTIM_DESERIALIZE_TORCH_ARG(betas_t, betas);
_TORCH_OPTIM_DESERIALIZE_TORCH_ARG(double, eps);
_TORCH_OPTIM_DESERIALIZE_TORCH_ARG(double, weight_decay);
}

double RAdamOptions::get_lr() const
{
return lr();
}

void RAdamOptions::set_lr(const double lr)
{
this->lr(lr);
}

bool operator==(const RAdamParamState &lhs, const RAdamParamState &rhs)
{
return (lhs.step() == rhs.step())
&& torch::equal(lhs.exp_avg(), rhs.exp_avg())
&& torch::equal(lhs.exp_avg_sq(), rhs.exp_avg_sq());
}

void
RAdamParamState::serialize(torch::serialize::OutputArchive &archive) const
{
_TORCH_OPTIM_SERIALIZE_TORCH_ARG(step);
_TORCH_OPTIM_SERIALIZE_TORCH_ARG(exp_avg);
_TORCH_OPTIM_SERIALIZE_TORCH_ARG(exp_avg_sq);
}

void RAdamParamState::serialize(torch::serialize::InputArchive &archive)
{
_TORCH_OPTIM_DESERIALIZE_TORCH_ARG(int64_t, step);
_TORCH_OPTIM_DESERIALIZE_TORCH_ARG(torch::Tensor, exp_avg);
_TORCH_OPTIM_DESERIALIZE_TORCH_ARG(torch::Tensor, exp_avg_sq);
}

torch::Tensor RAdam::step(LossClosure closure)
{
torch::NoGradGuard no_grad;
torch::Tensor loss = {};
if (closure != nullptr)
{
at::AutoGradMode enable_grad(true);
loss = closure();
}
for (auto &group : param_groups_)
{
for (auto &p : group.params())
{
if (!p.grad().defined())
{
continue;
}
auto grad = p.grad();
TORCH_CHECK(
!grad.is_sparse(), "RAdam does not support sparse gradients" /*, please consider SparseRAdam instead*/);
auto param_state
= state_.find(c10::guts::to_string(p.unsafeGetTensorImpl()));
auto &options = static_cast<RAdamOptions &>(group.options());

// State initialization
if (param_state == state_.end())
{
auto state = std::make_unique<RAdamParamState>();
state->step(0);
// Exponential moving average of gradient values
state->exp_avg(
torch::zeros_like(p, torch::MemoryFormat::Preserve));
// Exponential moving average of squared gradient values
state->exp_avg_sq(
torch::zeros_like(p, torch::MemoryFormat::Preserve));
state_[c10::guts::to_string(p.unsafeGetTensorImpl())]
= std::move(state);
}

auto &state = static_cast<RAdamParamState &>(
*state_[c10::guts::to_string(p.unsafeGetTensorImpl())]);
auto &exp_avg = state.exp_avg();
auto &exp_avg_sq = state.exp_avg_sq();

state.step(state.step() + 1);
auto beta1 = std::get<0>(options.betas());
auto beta2 = std::get<1>(options.betas());

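            // Bias corrections for the zero-initialized first and second
            // moment estimates: 1 - beta^t.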
auto bias_correction1 = 1 - std::pow(beta1, state.step());
auto bias_correction2 = 1 - std::pow(beta2, state.step());

if (options.weight_decay() != 0)
{
grad = grad.add(p, options.weight_decay());
}

// Decay the first and second moment running average coefficient
exp_avg.mul_(beta1).add_(grad, 1 - beta1);
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, 1 - beta2);

auto N_sma_max = 2.0 / (1.0 - beta2) - 1.0; // rho_inf
auto beta2_t = std::pow(beta2, (float)state.step());
auto N_sma
= N_sma_max
- 2.0 * (float)state.step() * beta2_t / bias_correction2;
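            // Rectification: once the approximated SMA length N_sma exceeds
            // the threshold of 5, the variance of the adaptive learning rate
            // is considered tractable and the variance-rectified step is
            // taken; otherwise the update falls back to a plain
            // bias-corrected momentum step without the adaptive denominator.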
if (N_sma >= 5.0)
{
auto step_size
= sqrt((1.0 - beta2_t) * (N_sma - 4.0) * (N_sma - 2.0)
* N_sma_max / (N_sma_max - 4.0) / (N_sma_max - 2.0)
/ N_sma)
/ bias_correction1;
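                // Denominator sqrt(v_t + eps); note that eps is added in
                // place to the stored second-moment estimate before the
                // square root.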
torch::Tensor denom = exp_avg_sq.add_(options.eps()).sqrt();
auto perturb = exp_avg / denom;
step_size *= options.lr();
p.add_(perturb, -step_size);
}
else
{
auto step_size = options.lr() / bias_correction1;
auto perturb = exp_avg;
p.add_(perturb, -step_size);
}
}
}
return loss;
}

void RAdam::save(torch::serialize::OutputArchive &archive) const
{
serialize(*this, archive);
}

void RAdam::load(torch::serialize::InputArchive &archive)
{
torch::IValue pytorch_version;
if (archive.try_read("pytorch_version", pytorch_version))
{
serialize(*this, archive);
}
else
{ // deserializing archives saved in old format (prior to
// version 1.5.0)
TORCH_WARN("Your serialized RAdam optimizer is still using the old "
"serialization format. "
"You should re-save your RAdam optimizer to use the new "
"serialization format.");
}
}
} // namespace dd
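For reference, the update implemented by RAdam::step() above is the rectified Adam rule of Liu et al. ("On the Variance of the Adaptive Learning Rate and Beyond"), with weight decay folded into the gradient before the moment updates. A sketch of the per-parameter update in LaTeX, using g_t for the gradient, eta for the learning rate, and the same bias corrections as the code (N_sma_max and N_sma in the code play the roles of rho_inf and rho_t):

\[
\begin{aligned}
g_t &\leftarrow g_t + \lambda\,\theta_{t-1} \quad\text{(weight decay)}\\
m_t &= \beta_1 m_{t-1} + (1-\beta_1)\,g_t, \qquad
v_t = \beta_2 v_{t-1} + (1-\beta_2)\,g_t^2,\\
\rho_\infty &= \frac{2}{1-\beta_2} - 1, \qquad
\rho_t = \rho_\infty - \frac{2\,t\,\beta_2^t}{1-\beta_2^t}, \qquad
r_t = \sqrt{\frac{(\rho_t-4)(\rho_t-2)\,\rho_\infty}{(\rho_\infty-4)(\rho_\infty-2)\,\rho_t}},\\
\theta_t &=
\begin{cases}
\theta_{t-1} - \eta\,\dfrac{r_t\,\sqrt{1-\beta_2^t}}{1-\beta_1^t}\cdot\dfrac{m_t}{\sqrt{v_t+\epsilon}} & \text{if } \rho_t \ge 5,\\[1ex]
\theta_{t-1} - \dfrac{\eta}{1-\beta_1^t}\,m_t & \text{otherwise.}
\end{cases}
\end{aligned}
\]

(In the code, epsilon is added in place to the stored v_t before the square root.)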
120 changes: 120 additions & 0 deletions src/backends/torch/optim/radam.h
@@ -0,0 +1,120 @@
/**
* DeepDetect
* Copyright (c) 2019-2023 Jolibrain
* Author: Guillaume Infantes <guillaume.infantes@jolibrain.com>
*
* This file is part of deepdetect.
*
* deepdetect is free software: you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* deepdetect is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public License
* along with deepdetect. If not, see <http://www.gnu.org/licenses/>.
*/
#pragma once

#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#include <torch/arg.h>
#include <torch/nn/module.h>
#include <torch/optim/optimizer.h>
#include <torch/optim/serialize.h>
#pragma GCC diagnostic pop

#include <utility>
#include <vector>

namespace torch
{
namespace serialize
{
class OutputArchive;
class InputArchive;
} // namespace serialize
} // namespace torch

namespace dd
{
struct TORCH_API RAdamOptions
: public torch::optim::OptimizerCloneableOptions<RAdamOptions>
{
RAdamOptions(double lr = 1e-3);
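    // TORCH_ARG(T, name) declares a private member name_ together with
    // chainable name() getter / name(value) setter accessors.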
TORCH_ARG(double, lr) = 1e-3;
typedef std::tuple<double, double> betas_t;
TORCH_ARG(betas_t, betas) = std::make_tuple(0.9, 0.999);
TORCH_ARG(double, eps) = 1e-8;
TORCH_ARG(double, weight_decay) = 0;

public:
void serialize(torch::serialize::InputArchive &archive) override;
void serialize(torch::serialize::OutputArchive &archive) const override;
TORCH_API friend bool operator==(const RAdamOptions &lhs,
const RAdamOptions &rhs);
~RAdamOptions() override = default;
double get_lr() const override;
void set_lr(const double lr) override;
};

struct TORCH_API RAdamParamState
: public torch::optim::OptimizerCloneableParamState<RAdamParamState>
{
TORCH_ARG(int64_t, step) = 0;
TORCH_ARG(torch::Tensor, exp_avg);
TORCH_ARG(torch::Tensor, exp_avg_sq);
TORCH_ARG(torch::Tensor, max_exp_avg_sq) = {};
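    // Not used by RAdam::step(); kept in line with the layout of torch's
    // AdamParamState.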

public:
void serialize(torch::serialize::InputArchive &archive) override;
void serialize(torch::serialize::OutputArchive &archive) const override;
TORCH_API friend bool operator==(const RAdamParamState &lhs,
const RAdamParamState &rhs);
~RAdamParamState() override = default;
};

class TORCH_API RAdam : public torch::optim::Optimizer
{
public:
explicit RAdam(std::vector<torch::optim::OptimizerParamGroup> param_groups,
RAdamOptions defaults = {})
: Optimizer(std::move(param_groups),
std::make_unique<RAdamOptions>(defaults))
{
TORCH_CHECK(defaults.lr() >= 0,
"Invalid learning rate: ", defaults.lr());
TORCH_CHECK(defaults.eps() >= 0,
"Invalid epsilon value: ", defaults.eps());
auto betas = defaults.betas();
TORCH_CHECK(0 <= std::get<0>(betas) && std::get<0>(betas) < 1.0,
"Invalid beta parameter at index 0: ", std::get<0>(betas));
TORCH_CHECK(0 <= std::get<1>(betas) && std::get<1>(betas) < 1.0,
"Invalid beta parameter at index 1: ", std::get<1>(betas));
TORCH_CHECK(defaults.weight_decay() >= 0,
"Invalid weight_decay value: ", defaults.weight_decay());
}
explicit RAdam(std::vector<torch::Tensor> params,
// NOLINTNEXTLINE(performance-move-const-arg)
RAdamOptions defaults = {})
: RAdam({ std::move(torch::optim::OptimizerParamGroup(params)) },
defaults)
{
}

torch::Tensor step(LossClosure closure = nullptr) override;
void save(torch::serialize::OutputArchive &archive) const override;
void load(torch::serialize::InputArchive &archive) override;

private:
template <typename Self, typename Archive>
static void serialize(Self &self, Archive &archive)
{
_TORCH_OPTIM_SERIALIZE_WITH_TEMPLATE_ARG(RAdam);
}
};
} // namespace dd
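A minimal standalone usage sketch of the class declared above, assuming libtorch is available; the include path, toy model and hyper-parameter values are illustrative only, mirroring how torchsolver.cc below builds the optimizer from _params and the configured solver options:

#include "backends/torch/optim/radam.h" // assumed path inside the DeepDetect source tree
#include <torch/torch.h>

int main()
{
  // Toy model and data, purely for illustration.
  torch::nn::Linear model(10, 1);
  auto input = torch::randn({ 8, 10 });
  auto target = torch::randn({ 8, 1 });

  // Construct the optimizer from the model parameters with explicit options.
  dd::RAdam optimizer(model->parameters(),
                      dd::RAdamOptions(1e-3)
                          .betas(std::make_tuple(0.9, 0.999))
                          .weight_decay(1e-4));

  for (int iter = 0; iter < 100; ++iter)
    {
      optimizer.zero_grad();
      auto loss = torch::mse_loss(model->forward(input), target);
      loss.backward();
      optimizer.step();
    }
  return 0;
}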
12 changes: 12 additions & 0 deletions src/backends/torch/torchsolver.cc
@@ -21,6 +21,7 @@

#include "torchsolver.h"
#include "optim/ranger.h"
#include "optim/radam.h"
#include "optim/madgrad.h"

namespace dd
@@ -98,6 +99,17 @@ namespace dd
.betas(std::make_tuple(_beta1, _beta2))
.weight_decay(_weight_decay)));
}
else if (_solver_type == "RADAM")
{
_optimizer = std::unique_ptr<torch::optim::Optimizer>(
new RAdam(_params, RAdamOptions(_base_lr)
.betas(std::make_tuple(_beta1, _beta2))
.weight_decay(_weight_decay)));
this->_logger->info("beta_1: {}", _beta1);
this->_logger->info("beta_2: {}", _beta2);
this->_logger->info("weight_decay: {}", _weight_decay);
this->_logger->info("rectified: TRUE");
}
else if (_solver_type == "ADAMW")
{
if (_weight_decay == 0.0)