Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Convergence history #1517

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 44 additions & 15 deletions core/log/convergence.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,34 +57,55 @@ void Convergence<ValueType>::on_iteration_complete(
const LinOp* residual_norm, const LinOp* implicit_resnorm_sq,
const array<stopping_status>* status, const bool stopped) const
{
auto update_history = [&](auto& container, auto new_val, bool is_norm) {
if (history_ == convergence_history::none) {
if (container.empty()) {
container.emplace_back(nullptr);
}
container.back() = std::move(new_val);
return;
}
if (is_norm || history_ == convergence_history::full) {
container.emplace_back(std::move(new_val));
}
};
if (num_iterations == 0) {
residual_.clear();
residual_norm_.clear();
implicit_sq_resnorm_.clear();
}
if (stopped) {
array<stopping_status> tmp(status->get_executor()->get_master(),
*status);
this->convergence_status_ = true;
convergence_status_ = true;
for (int i = 0; i < status->get_size(); i++) {
if (!tmp.get_data()[i].has_converged()) {
this->convergence_status_ = false;
convergence_status_ = false;
break;
}
}
this->num_iterations_ = num_iterations;
num_iterations_ = num_iterations;
}
if (stopped || history_ != convergence_history::none) {
if (residual != nullptr) {
this->residual_.reset(residual->clone().release());
update_history(residual_, residual->clone(), false);
}
if (implicit_resnorm_sq != nullptr) {
this->implicit_sq_resnorm_.reset(
implicit_resnorm_sq->clone().release());
update_history(implicit_sq_resnorm_, implicit_resnorm_sq->clone(),
true);
}
if (residual_norm != nullptr) {
this->residual_norm_.reset(residual_norm->clone().release());
update_history(residual_norm_, residual_norm->clone(), true);
} else if (residual != nullptr) {
using NormVector = matrix::Dense<remove_complex<ValueType>>;
detail::vector_dispatch<ValueType>(
residual, [&](const auto* dense_r) {
this->residual_norm_ =
update_history(
residual_norm_,
NormVector::create(residual->get_executor(),
dim<2>{1, residual->get_size()[1]});
dense_r->compute_norm2(this->residual_norm_);
dim<2>{1, residual->get_size()[1]}),
true);
dense_r->compute_norm2(residual_norm_.back());
});
} else if (dynamic_cast<const solver::detail::SolverBaseLinOp*>(
solver) &&
Expand All @@ -97,13 +118,21 @@ void Convergence<ValueType>::on_iteration_complete(
detail::vector_dispatch<ValueType>(b, [&](const auto* dense_b) {
detail::vector_dispatch<ValueType>(x, [&](const auto* dense_x) {
auto exec = system_mtx->get_executor();
auto residual = dense_b->clone();
this->residual_norm_ = NormVector::create(
exec, dim<2>{1, residual->get_size()[1]});
update_history(residual_, dense_b->clone(), false);
system_mtx->apply(initialize<Vector>({-1.0}, exec), dense_x,
initialize<Vector>({1.0}, exec),
residual);
residual->compute_norm2(this->residual_norm_);
residual_.back());
update_history(
residual_norm_,
NormVector::create(
exec, dim<2>{1, residual_.back()->get_size()[1]}),
true);
detail::vector_dispatch<ValueType>(
residual_.back().get(),
[&](const auto* actual_residual) {
actual_residual->compute_norm2(
residual_norm_.back());
});
});
});
}
Expand Down
92 changes: 92 additions & 0 deletions core/test/log/convergence.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,9 @@ TYPED_TEST(Convergence, CanGetEmptyData)
ASSERT_EQ(logger->get_residual(), nullptr);
ASSERT_EQ(logger->get_residual_norm(), nullptr);
ASSERT_EQ(logger->get_implicit_sq_resnorm(), nullptr);
ASSERT_TRUE(logger->get_residual_history().empty());
ASSERT_TRUE(logger->get_residual_norm_history().empty());
ASSERT_TRUE(logger->get_implicit_sq_resnorm_history().empty());
}


Expand Down Expand Up @@ -100,6 +103,10 @@ TYPED_TEST(Convergence, DoesNotLogIfNotStopped)
ASSERT_EQ(logger->get_num_iterations(), 0);
ASSERT_EQ(logger->get_residual(), nullptr);
ASSERT_EQ(logger->get_residual_norm(), nullptr);
ASSERT_EQ(logger->get_implicit_sq_resnorm(), nullptr);
ASSERT_TRUE(logger->get_residual_history().empty());
ASSERT_TRUE(logger->get_residual_norm_history().empty());
ASSERT_TRUE(logger->get_implicit_sq_resnorm_history().empty());
}


Expand Down Expand Up @@ -131,4 +138,89 @@ TYPED_TEST(Convergence, CanComputeResidualNormFromSolution)
}


TYPED_TEST(Convergence, CanLogDataWithNormHistory)
{
using AbsoluteDense = gko::matrix::Dense<gko::remove_complex<TypeParam>>;
auto logger = gko::log::Convergence<TypeParam>::create(
gko::convergence_history::norm);

logger->template on<gko::log::Logger::iteration_complete>(
this->system.get(), nullptr, nullptr, 100, nullptr,
this->residual_norm.get(), this->implicit_sq_resnorm.get(), nullptr,
false);
logger->template on<gko::log::Logger::iteration_complete>(
this->system.get(), nullptr, nullptr, 101, nullptr,
this->residual_norm.get(), this->implicit_sq_resnorm.get(),
&this->status, true);

ASSERT_EQ(logger->get_residual_history().size(), 0);
ASSERT_EQ(logger->get_residual_norm_history().size(), 2);
ASSERT_EQ(logger->get_implicit_sq_resnorm_history().size(), 2);
for (int i : {0, 1}) {
GKO_ASSERT_MTX_NEAR(
gko::as<AbsoluteDense>(logger->get_residual_norm_history()[i]),
this->residual_norm, 0);
GKO_ASSERT_MTX_NEAR(gko::as<AbsoluteDense>(
logger->get_implicit_sq_resnorm_history()[i]),
this->implicit_sq_resnorm, 0);
}
}


TYPED_TEST(Convergence, CanLogDataWithFullHistory)
{
using Dense = gko::matrix::Dense<TypeParam>;
using AbsoluteDense = gko::matrix::Dense<gko::remove_complex<TypeParam>>;
auto logger = gko::log::Convergence<TypeParam>::create(
gko::convergence_history::full);

logger->template on<gko::log::Logger::iteration_complete>(
this->system.get(), nullptr, nullptr, 100, this->residual.get(),
this->residual_norm.get(), this->implicit_sq_resnorm.get(), nullptr,
false);
logger->template on<gko::log::Logger::iteration_complete>(
this->system.get(), nullptr, nullptr, 101, this->residual.get(),
this->residual_norm.get(), this->implicit_sq_resnorm.get(),
&this->status, true);

ASSERT_EQ(logger->get_residual_history().size(), 2);
ASSERT_EQ(logger->get_residual_norm_history().size(), 2);
ASSERT_EQ(logger->get_implicit_sq_resnorm_history().size(), 2);
for (int i : {0, 1}) {
GKO_ASSERT_MTX_NEAR(gko::as<Dense>(logger->get_residual_history()[i]),
this->residual, 0);
GKO_ASSERT_MTX_NEAR(
gko::as<AbsoluteDense>(logger->get_residual_norm_history()[i]),
this->residual_norm, 0);
GKO_ASSERT_MTX_NEAR(gko::as<AbsoluteDense>(
logger->get_implicit_sq_resnorm_history()[i]),
this->implicit_sq_resnorm, 0);
}
}


TYPED_TEST(Convergence, CanClearHistory)
{
auto logger = gko::log::Convergence<TypeParam>::create(
gko::convergence_history::full);

logger->template on<gko::log::Logger::iteration_complete>(
this->system.get(), nullptr, nullptr, 100, this->residual.get(),
this->residual_norm.get(), this->implicit_sq_resnorm.get(), nullptr,
false);
logger->template on<gko::log::Logger::iteration_complete>(
this->system.get(), nullptr, nullptr, 101, this->residual.get(),
this->residual_norm.get(), this->implicit_sq_resnorm.get(),
&this->status, true);
logger->template on<gko::log::Logger::iteration_complete>(
this->system.get(), nullptr, nullptr, 0, this->residual.get(),
this->residual_norm.get(), this->implicit_sq_resnorm.get(), nullptr,
false);

ASSERT_EQ(logger->get_residual_history().size(), 1);
ASSERT_EQ(logger->get_residual_norm_history().size(), 1);
ASSERT_EQ(logger->get_implicit_sq_resnorm_history().size(), 1);
}


} // namespace
71 changes: 42 additions & 29 deletions include/ginkgo/core/base/utils_helper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,35 @@ inline typename std::enable_if<!detail::have_ownership_s<Pointer>::value,
}


/**
* This is a deleter that does not delete the object.
*
* It is useful where the object has been allocated elsewhere and will be
* deleted manually.
*/
template <typename T>
class null_deleter {
public:
using pointer = T*;

/**
* Deletes the object.
*
* @param ptr pointer to the object being deleted
*/
void operator()(pointer) const noexcept {}
};

// a specialization for arrays
template <typename T>
class null_deleter<T[]> {
public:
using pointer = T[];

void operator()(pointer) const noexcept {}
};


/**
* Performs polymorphic type conversion.
*
Expand Down Expand Up @@ -406,6 +435,19 @@ inline std::unique_ptr<std::decay_t<T>> as(std::unique_ptr<U>&& obj)
}
}

template <typename T, typename U>
inline std::unique_ptr<const std::decay_t<T>,
null_deleter<const std::decay_t<T>>>
as(const std::unique_ptr<U>& obj)
{
if (auto p = dynamic_cast<const std::decay_t<T>*>(obj.get())) {
return {p, null_deleter<const std::decay_t<T>>{}};
} else {
throw NotSupported(__FILE__, __LINE__, __func__,
name_demangling::get_type_name(typeid(*obj)));
}
}


/**
* Performs polymorphic type conversion of a shared_ptr.
Expand Down Expand Up @@ -457,35 +499,6 @@ inline std::shared_ptr<const std::decay_t<T>> as(std::shared_ptr<const U> obj)
}


/**
* This is a deleter that does not delete the object.
*
* It is useful where the object has been allocated elsewhere and will be
* deleted manually.
*/
template <typename T>
class null_deleter {
public:
using pointer = T*;

/**
* Deletes the object.
*
* @param ptr pointer to the object being deleted
*/
void operator()(pointer) const noexcept {}
};

// a specialization for arrays
template <typename T>
class null_deleter<T[]> {
public:
using pointer = T[];

void operator()(pointer) const noexcept {}
};


} // namespace gko


Expand Down
Loading
Loading