Skip to content

Commit

Permalink
Merge fix for Matrix constructor with LinOp* parameter
Browse files Browse the repository at this point in the history
This PR fixes the distributed matrix creation with local and non-local templates passed in as LinOp*. Before, the LinOp* overload would only be used if the input argument has exactly the type const LinOp*. In all other cases, eg an input as LinOp* or Csr<...>*, the templated constructor would be chosen, because, after template deduction, that constructor matches better than the LinOp* one. To prevent this from happening, this PR enables the templated constructor only if the template parameter has a `create<ValueType, IndexType>(exec)` function.

Related PR: #1148
  • Loading branch information
MarcelKoch committed Nov 1, 2022
2 parents 9242dd3 + f847a60 commit 747fc7e
Show file tree
Hide file tree
Showing 3 changed files with 107 additions and 25 deletions.
2 changes: 1 addition & 1 deletion core/test/mpi/distributed/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
ginkgo_create_test(matrix MPI_SIZE 3)
ginkgo_create_test(matrix MPI_SIZE 1)
76 changes: 68 additions & 8 deletions core/test/mpi/distributed/matrix.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,9 +118,10 @@ class MatrixBuilder : public ::testing::Test {
{
SCOPED_TRACE("With Csr with strategy");
using ConcreteCsr = Csr<value_type, local_index_type>;
f(gko::with_matrix_type<Csr>(
std::make_shared<typename ConcreteCsr::classical>()),
ConcreteCsr::create(this->ref), [](const gko::LinOp* local_mat) {
auto strategy = std::make_shared<typename ConcreteCsr::classical>();
f(gko::with_matrix_type<Csr>(strategy),
ConcreteCsr::create(this->ref, strategy),
[](const gko::LinOp* local_mat) {
auto local_csr = gko::as<ConcreteCsr>(local_mat);

ASSERT_NO_THROW(gko::as<typename ConcreteCsr::classical>(
Expand All @@ -141,7 +142,7 @@ class MatrixBuilder : public ::testing::Test {
{
SCOPED_TRACE("With Fbcsr with block_size");
f(gko::with_matrix_type<Fbcsr>(5),
Fbcsr<value_type, local_index_type>::create(this->ref),
Fbcsr<value_type, local_index_type>::create(this->ref, 5),
[](const gko::LinOp* local_mat) {
auto local_fbcsr =
gko::as<Fbcsr<value_type, local_index_type>>(local_mat);
Expand All @@ -158,9 +159,11 @@ class MatrixBuilder : public ::testing::Test {
{
SCOPED_TRACE("With Hybrid with strategy");
using Concrete = Hybrid<value_type, local_index_type>;
f(gko::with_matrix_type<Hybrid>(
std::make_shared<typename Concrete::column_limit>(11)),
Concrete::create(this->ref), [](const gko::LinOp* local_mat) {
auto strategy =
std::make_shared<typename Concrete::column_limit>(11);
f(gko::with_matrix_type<Hybrid>(strategy),
Concrete::create(this->ref, strategy),
[](const gko::LinOp* local_mat) {
auto local_hy = gko::as<Concrete>(local_mat);

ASSERT_NO_THROW(gko::as<typename Concrete::column_limit>(
Expand Down Expand Up @@ -238,7 +241,7 @@ TYPED_TEST(MatrixBuilder, BuildWithLocalAndNonLocal)
auto additional_local_test) {
using expected_local_type = typename std::remove_pointer<decltype(
expected_local_type_ptr.get())>::type;
this->forall_matrix_types([=](auto with_non_local_matrix_type,
this->forall_matrix_types([&](auto with_non_local_matrix_type,
auto expected_non_local_type_ptr,
auto additional_non_local_test) {
using expected_non_local_type =
Expand Down Expand Up @@ -279,4 +282,61 @@ TYPED_TEST(MatrixBuilder, BuildWithCustomLinOp)
}


TYPED_TEST(MatrixBuilder, BuildFromLinOpLocal)
{
using value_type = typename TestFixture::value_type;
using index_type = typename TestFixture::local_index_type;
using dist_mat_type = typename TestFixture::dist_mtx_type;
this->template forall_matrix_types([this](auto with_matrix_type,
auto expected_type_ptr,
auto additional_test) {
using expected_type = typename std::remove_pointer<decltype(
expected_type_ptr.get())>::type;

auto mat = dist_mat_type ::create(this->ref, this->comm,
expected_type_ptr.get());

ASSERT_NO_THROW(gko::as<expected_type>(mat->get_local_matrix()));
additional_test(mat->get_local_matrix().get());
additional_test(mat->get_non_local_matrix().get());
this->expected_interface_no_throw(mat.get(), with_matrix_type,
with_matrix_type);
});
}


TYPED_TEST(MatrixBuilder, BuildFromLinOpLocalAndNonLocal)
{
using value_type = typename TestFixture::value_type;
using index_type = typename TestFixture::local_index_type;
using dist_mat_type = typename TestFixture::dist_mtx_type;
this->template forall_matrix_types([this](auto with_local_matrix_type,
auto expected_local_type_ptr,
auto additional_local_test) {
using expected_local_type = typename std::remove_pointer<decltype(
expected_local_type_ptr.get())>::type;
this->forall_matrix_types([&](auto with_non_local_matrix_type,
auto expected_non_local_type_ptr,
auto additional_non_local_test) {
using expected_non_local_type =
typename std::remove_pointer<decltype(
expected_non_local_type_ptr.get())>::type;

auto mat = dist_mat_type ::create(
this->ref, this->comm, expected_local_type_ptr.get(),
expected_non_local_type_ptr.get());

ASSERT_NO_THROW(
gko::as<expected_local_type>(mat->get_local_matrix()));
ASSERT_NO_THROW(
gko::as<expected_non_local_type>(mat->get_non_local_matrix()));
additional_local_test(mat->get_local_matrix().get());
additional_non_local_test(mat->get_non_local_matrix().get());
this->expected_interface_no_throw(mat.get(), with_local_matrix_type,
with_non_local_matrix_type);
});
});
}


} // namespace
54 changes: 38 additions & 16 deletions include/ginkgo/core/distributed/matrix.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,24 @@ class Csr;
namespace detail {


/**
* Helper struct to test if the Builder type has a function create<ValueType,
* IndexType>(std::shared_ptr<const Executor>).
*/
template <typename Builder, typename ValueType, typename IndexType,
typename = void>
struct is_matrix_type_builder : std::false_type {};


template <typename Builder, typename ValueType, typename IndexType>
struct is_matrix_type_builder<
Builder, ValueType, IndexType,
gko::xstd::void_t<decltype(
std::declval<Builder>().template create<ValueType, IndexType>(
std::declval<std::shared_ptr<const Executor>>()))>>
: std::true_type {};


template <template <typename, typename> class MatrixType,
typename... CreateArgs>
struct MatrixTypeBuilderFromValueAndIndex {
Expand Down Expand Up @@ -123,7 +141,7 @@ template <template <typename, typename> class MatrixType, typename... Args>
auto with_matrix_type(Args&&... create_args)
{
return detail::MatrixTypeBuilderFromValueAndIndex<MatrixType, Args...>{
std::make_tuple(create_args...)};
std::forward_as_tuple(create_args...)};
}


Expand Down Expand Up @@ -417,14 +435,15 @@ class Matrix
* same type as `create` returns. It should be the
* return value of make_matrix_template.
*/
template <typename MatrixType>
template <typename MatrixType,
typename = std::enable_if_t<detail::is_matrix_type_builder<
MatrixType, ValueType, LocalIndexType>::value>>
explicit Matrix(std::shared_ptr<const Executor> exec,
mpi::communicator comm, MatrixType matrix_template)
: Matrix(exec, comm,
static_cast<const LinOp*>(
matrix_template
.template create<ValueType, LocalIndexType>(exec)
.get()))
: Matrix(
exec, comm,
matrix_template.template create<ValueType, LocalIndexType>(exec)
.get())
{}

/**
Expand Down Expand Up @@ -453,20 +472,23 @@ class Matrix
* `create` returns. It should be the
* return value of make_matrix_template.
*/
template <typename LocalMatrixType, typename NonLocalMatrixType>
template <typename LocalMatrixType, typename NonLocalMatrixType,
typename = std::enable_if_t<
detail::is_matrix_type_builder<LocalMatrixType, ValueType,
LocalIndexType>::value &&
detail::is_matrix_type_builder<NonLocalMatrixType, ValueType,
LocalIndexType>::value>>
explicit Matrix(std::shared_ptr<const Executor> exec,
mpi::communicator comm,
LocalMatrixType local_matrix_template,
NonLocalMatrixType non_local_matrix_template)
: Matrix(exec, comm,
static_cast<const LinOp*>(
local_matrix_template
.template create<ValueType, LocalIndexType>(exec)
.get()),
static_cast<const LinOp*>(
non_local_matrix_template
.template create<ValueType, LocalIndexType>(exec)
.get()))
local_matrix_template
.template create<ValueType, LocalIndexType>(exec)
.get(),
non_local_matrix_template
.template create<ValueType, LocalIndexType>(exec)
.get())
{}

/**
Expand Down

0 comments on commit 747fc7e

Please sign in to comment.