Skip to content

Commit

Permalink
[prec] add gauss seidel parsing
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcelKoch committed Jul 10, 2024
1 parent a94fed5 commit f8cbc05
Show file tree
Hide file tree
Showing 6 changed files with 120 additions and 29 deletions.
1 change: 1 addition & 0 deletions core/config/config_helper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ enum class LinOpFactoryType : int {
ParIct,
ParIlu,
ParIlut,
GaussSeidel,
Ic,
Ilu,
Isai,
Expand Down
2 changes: 2 additions & 0 deletions core/config/preconditioner_config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include <ginkgo/core/base/exception_helpers.hpp>
#include <ginkgo/core/config/config.hpp>
#include <ginkgo/core/config/registry.hpp>
#include <ginkgo/core/preconditioner/gauss_seidel.hpp>
#include <ginkgo/core/preconditioner/ic.hpp>
#include <ginkgo/core/preconditioner/ilu.hpp>
#include <ginkgo/core/preconditioner/isai.hpp>
Expand Down Expand Up @@ -294,6 +295,7 @@ deferred_factory_parameter<gko::LinOpFactory> parse<LinOpFactoryType::Isai>(
}


GKO_PARSE_VALUE_AND_INDEX_TYPE(GaussSeidel, gko::preconditioner::GaussSeidel);
GKO_PARSE_VALUE_AND_INDEX_TYPE(Jacobi, gko::preconditioner::Jacobi);
GKO_PARSE_VALUE_AND_INDEX_TYPE(Sor, gko::preconditioner::Sor);

Expand Down
58 changes: 30 additions & 28 deletions core/config/registry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,34 +16,36 @@ namespace config {

configuration_map generate_config_map()
{
return {{"solver::Cg", parse<LinOpFactoryType::Cg>},
{"solver::Bicg", parse<LinOpFactoryType::Bicg>},
{"solver::Bicgstab", parse<LinOpFactoryType::Bicgstab>},
{"solver::Fcg", parse<LinOpFactoryType::Fcg>},
{"solver::Cgs", parse<LinOpFactoryType::Cgs>},
{"solver::Ir", parse<LinOpFactoryType::Ir>},
{"solver::Idr", parse<LinOpFactoryType::Idr>},
{"solver::Gcr", parse<LinOpFactoryType::Gcr>},
{"solver::Gmres", parse<LinOpFactoryType::Gmres>},
{"solver::CbGmres", parse<LinOpFactoryType::CbGmres>},
{"solver::Direct", parse<LinOpFactoryType::Direct>},
{"solver::LowerTrs", parse<LinOpFactoryType::LowerTrs>},
{"solver::UpperTrs", parse<LinOpFactoryType::UpperTrs>},
{"factorization::Ic", parse<LinOpFactoryType::Factorization_Ic>},
{"factorization::Ilu", parse<LinOpFactoryType::Factorization_Ilu>},
{"factorization::Cholesky", parse<LinOpFactoryType::Cholesky>},
{"factorization::Lu", parse<LinOpFactoryType::Lu>},
{"factorization::ParIc", parse<LinOpFactoryType::ParIc>},
{"factorization::ParIct", parse<LinOpFactoryType::ParIct>},
{"factorization::ParIlu", parse<LinOpFactoryType::ParIlu>},
{"factorization::ParIlut", parse<LinOpFactoryType::ParIlut>},
{"preconditioner::Ic", parse<LinOpFactoryType::Ic>},
{"preconditioner::Ilu", parse<LinOpFactoryType::Ilu>},
{"preconditioner::Isai", parse<LinOpFactoryType::Isai>},
{"preconditioner::Jacobi", parse<LinOpFactoryType::Jacobi>},
{"preconditioner::Sor", parse<LinOpFactoryType::Sor>},
{"solver::Multigrid", parse<LinOpFactoryType::Multigrid>},
{"multigrid::Pgm", parse<LinOpFactoryType::Pgm>}};
return {
{"solver::Cg", parse<LinOpFactoryType::Cg>},
{"solver::Bicg", parse<LinOpFactoryType::Bicg>},
{"solver::Bicgstab", parse<LinOpFactoryType::Bicgstab>},
{"solver::Fcg", parse<LinOpFactoryType::Fcg>},
{"solver::Cgs", parse<LinOpFactoryType::Cgs>},
{"solver::Ir", parse<LinOpFactoryType::Ir>},
{"solver::Idr", parse<LinOpFactoryType::Idr>},
{"solver::Gcr", parse<LinOpFactoryType::Gcr>},
{"solver::Gmres", parse<LinOpFactoryType::Gmres>},
{"solver::CbGmres", parse<LinOpFactoryType::CbGmres>},
{"solver::Direct", parse<LinOpFactoryType::Direct>},
{"solver::LowerTrs", parse<LinOpFactoryType::LowerTrs>},
{"solver::UpperTrs", parse<LinOpFactoryType::UpperTrs>},
{"factorization::Ic", parse<LinOpFactoryType::Factorization_Ic>},
{"factorization::Ilu", parse<LinOpFactoryType::Factorization_Ilu>},
{"factorization::Cholesky", parse<LinOpFactoryType::Cholesky>},
{"factorization::Lu", parse<LinOpFactoryType::Lu>},
{"factorization::ParIc", parse<LinOpFactoryType::ParIc>},
{"factorization::ParIct", parse<LinOpFactoryType::ParIct>},
{"factorization::ParIlu", parse<LinOpFactoryType::ParIlu>},
{"factorization::ParIlut", parse<LinOpFactoryType::ParIlut>},
{"preconditioner::GaussSeidel", parse<LinOpFactoryType::GaussSeidel>},
{"preconditioner::Ic", parse<LinOpFactoryType::Ic>},
{"preconditioner::Ilu", parse<LinOpFactoryType::Ilu>},
{"preconditioner::Isai", parse<LinOpFactoryType::Isai>},
{"preconditioner::Jacobi", parse<LinOpFactoryType::Jacobi>},
{"preconditioner::Sor", parse<LinOpFactoryType::Sor>},
{"solver::Multigrid", parse<LinOpFactoryType::Multigrid>},
{"multigrid::Pgm", parse<LinOpFactoryType::Pgm>}};
}


Expand Down
31 changes: 31 additions & 0 deletions core/preconditioner/gauss_seidel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,42 @@
#include <ginkgo/core/preconditioner/gauss_seidel.hpp>
#include <ginkgo/core/preconditioner/sor.hpp>

#include "core/config/config_helper.hpp"


namespace gko {
namespace preconditioner {


template <typename ValueType, typename IndexType>
typename GaussSeidel<ValueType, IndexType>::parameters_type
GaussSeidel<ValueType, IndexType>::parse(
const config::pnode& config, const config::registry& context,
const config::type_descriptor& td_for_child)
{
auto params = GaussSeidel::build();

if (auto& obj = config.get("skip_sorting")) {
params.with_skip_sorting(config::get_value<bool>(obj));
}
if (auto& obj = config.get("symmetric")) {
params.with_symmetric(config::get_value<bool>(obj));
}
if (auto& obj = config.get("l_solver")) {
params.with_l_solver(
gko::config::parse_or_get_factory<const LinOpFactory>(
obj, context, td_for_child));
}
if (auto& obj = config.get("u_solver")) {
params.with_u_solver(
gko::config::parse_or_get_factory<const LinOpFactory>(
obj, context, td_for_child));
}

return params;
}


template <typename ValueType, typename IndexType>
std::unique_ptr<typename GaussSeidel<ValueType, IndexType>::composition_type>
GaussSeidel<ValueType, IndexType>::generate(
Expand Down
49 changes: 48 additions & 1 deletion core/test/config/preconditioner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#include <ginkgo/core/base/executor.hpp>
#include <ginkgo/core/config/config.hpp>
#include <ginkgo/core/preconditioner/gauss_seidel.hpp>
#include <ginkgo/core/preconditioner/ic.hpp>
#include <ginkgo/core/preconditioner/ilu.hpp>
#include <ginkgo/core/preconditioner/isai.hpp>
Expand Down Expand Up @@ -347,6 +348,52 @@ struct Sor
};


struct GaussSeidel
: PreconditionerConfigTest<
::gko::preconditioner::GaussSeidel<float, gko::int32>,
::gko::preconditioner::GaussSeidel<double, gko::int32>> {
using Ir = gko::solver::Ir<float>;

static pnode::map_type setup_base()
{
return {{"type", pnode{"preconditioner::GaussSeidel"}}};
}

static void change_template(pnode::map_type& config_map)
{
config_map["value_type"] = pnode{"float32"};
}

template <bool from_reg, typename ParamType>
static void set(pnode::map_type& config_map, ParamType& param, registry reg,
std::shared_ptr<const gko::Executor> exec)
{
config_map["skip_sorting"] = pnode{true};
param.with_skip_sorting(true);
config_map["symmetric"] = pnode{true};
param.with_symmetric(true);
config_map["l_solver"] = pnode{
{{"type", pnode{"solver::Ir"}}, {"value_type", pnode{"float32"}}}};
param.with_l_solver(DummyIr::build());
config_map["u_solver"] = pnode{
{{"type", pnode{"solver::Ir"}}, {"value_type", pnode{"float32"}}}};
param.with_u_solver(DummyIr::build());
}

template <bool from_reg, typename AnswerType>
static void validate(gko::LinOpFactory* result, AnswerType* answer)
{
auto res_param = gko::as<AnswerType>(result)->get_parameters();
auto ans_param = answer->get_parameters();

ASSERT_EQ(res_param.skip_sorting, ans_param.skip_sorting);
ASSERT_EQ(res_param.symmetric, ans_param.symmetric);
ASSERT_EQ(typeid(res_param.l_solver), typeid(ans_param.l_solver));
ASSERT_EQ(typeid(res_param.u_solver), typeid(ans_param.u_solver));
}
};


template <typename T>
class Preconditioner : public ::testing::Test {
protected:
Expand Down Expand Up @@ -378,7 +425,7 @@ class Preconditioner : public ::testing::Test {


using PreconditionerTypes =
::testing::Types<::Ic, ::Ilu, ::Isai, ::Jacobi, ::Sor>;
::testing::Types<::GaussSeidel, ::Ic, ::Ilu, ::Isai, ::Jacobi, ::Sor>;


TYPED_TEST_SUITE(Preconditioner, PreconditionerTypes, TypenameNameGenerator);
Expand Down
8 changes: 8 additions & 0 deletions include/ginkgo/core/preconditioner/gauss_seidel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include <ginkgo/core/base/composition.hpp>
#include <ginkgo/core/base/lin_op.hpp>
#include <ginkgo/core/base/polymorphic_object.hpp>
#include <ginkgo/core/config/config.hpp>


namespace gko {
Expand Down Expand Up @@ -83,6 +84,11 @@ class GaussSeidel
/** Creates a new parameter_type to set up the factory. */
static parameters_type build() { return {}; }

static parameters_type parse(
const config::pnode& config, const config::registry& context,
const config::type_descriptor& td_for_child =
config::make_type_descriptor<ValueType, IndexType>());

protected:
explicit GaussSeidel(std::shared_ptr<const Executor> exec,
const parameters_type& params = {})
Expand All @@ -96,6 +102,8 @@ class GaussSeidel
private:
parameters_type parameters_;
};


} // namespace preconditioner
} // namespace gko

Expand Down

0 comments on commit f8cbc05

Please sign in to comment.