Skip to content

Commit

Permalink
Add row and col scaling functions to distributed matrix
Browse files Browse the repository at this point in the history
  • Loading branch information
fritzgoebel committed Jul 3, 2024
1 parent 185c5f6 commit 9b99cc5
Show file tree
Hide file tree
Showing 3 changed files with 161 additions and 0 deletions.
50 changes: 50 additions & 0 deletions core/distributed/matrix.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <ginkgo/core/distributed/vector.hpp>
#include <ginkgo/core/matrix/coo.hpp>
#include <ginkgo/core/matrix/csr.hpp>
#include <ginkgo/core/matrix/diagonal.hpp>

#include "core/distributed/matrix_kernels.hpp"

Expand Down Expand Up @@ -504,6 +505,55 @@ void Matrix<ValueType, LocalIndexType, GlobalIndexType>::apply_impl(
}


template <typename ValueType, typename LocalIndexType, typename GlobalIndexType>
void Matrix<ValueType, LocalIndexType, GlobalIndexType>::col_scale(
ptr_param<const global_vector_type> scaling_factors)
{
GKO_ASSERT_CONFORMANT(this, scaling_factors.get());
auto exec = this->get_executor();
auto comm = this->get_communicator();
size_type n_local_cols = local_mtx_->get_size()[1];
size_type n_non_local_cols = non_local_mtx_->get_size()[1];
const auto scale_diag = gko::matrix::Diagonal<ValueType>::create_const(
exec, n_local_cols,
make_const_array_view(exec, n_local_cols,
scaling_factors->get_const_local_values()));

auto req = this->communicate(scaling_factors->get_local_vector());
scale_diag->rapply(local_mtx_, local_mtx_);
req.wait();
if (n_non_local_cols > 0) {
auto use_host_buffer = mpi::requires_host_buffer(exec, comm);
if (use_host_buffer) {
recv_buffer_->copy_from(host_recv_buffer_.get());
}
const auto non_local_scale_diag =
gko::matrix::Diagonal<ValueType>::create_const(
exec, n_non_local_cols,
make_const_array_view(exec, n_non_local_cols,
recv_buffer_->get_const_values()));
non_local_scale_diag->rapply(non_local_mtx_, non_local_mtx_);
}
}


template <typename ValueType, typename LocalIndexType, typename GlobalIndexType>
void Matrix<ValueType, LocalIndexType, GlobalIndexType>::row_scale(
ptr_param<const global_vector_type> scaling_factors)
{
GKO_ASSERT_EQUAL_ROWS(this, scaling_factors.get());
auto exec = this->get_executor();
size_type n_local_rows = local_mtx_->get_size()[0];
const auto scale_diag = gko::matrix::Diagonal<ValueType>::create_const(
exec, n_local_rows,
make_const_array_view(exec, n_local_rows,
scaling_factors->get_const_local_values()));

scale_diag->apply(local_mtx_, local_mtx_);
scale_diag->apply(non_local_mtx_, non_local_mtx_);
}


template <typename ValueType, typename LocalIndexType, typename GlobalIndexType>
Matrix<ValueType, LocalIndexType, GlobalIndexType>::Matrix(const Matrix& other)
: EnableDistributedLinOp<Matrix<value_type, local_index_type,
Expand Down
18 changes: 18 additions & 0 deletions include/ginkgo/core/distributed/matrix.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -575,6 +575,24 @@ class Matrix
std::vector<comm_index_type> recv_offsets,
array<local_index_type> recv_gather_idxs);

/**
* Scales the columns of the matrix by the respective entries of the vector.
* The vector's row partition has to be the same as the matrix's column
* partition. The scaling is done in-place.
*
* @param scaling_factors The vector containing the scaling factors.
*/
void col_scale(ptr_param<const global_vector_type> scaling_factors);

/**
* Scales the rows of the matrix by the respective entries of the vector.
* The vector and the matrix have to have the same row partition.
* The scaling is done in-place.
*
* @param scaling_factors The vector containing the scaling factors.
*/
void row_scale(ptr_param<const global_vector_type> scaling_factors);

protected:
explicit Matrix(std::shared_ptr<const Executor> exec,
mpi::communicator comm);
Expand Down
93 changes: 93 additions & 0 deletions test/mpi/matrix.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -516,6 +516,99 @@ TYPED_TEST(Matrix, CanAdvancedApplyToMultipleVectorsLarge)
}


TYPED_TEST(Matrix, CanColScale)
{
using value_type = typename TestFixture::value_type;
using index_type = typename TestFixture::global_index_type;
using csr = typename TestFixture::local_matrix_type;
using dist_vec_type = typename TestFixture::dist_vec_type;
auto vec_md = gko::matrix_data<value_type, index_type>{
I<I<value_type>>{{1}, {2}, {3}, {4}, {5}}};
auto rank = this->comm.rank();

auto col_scaling_factors = dist_vec_type::create(this->exec, this->comm);
col_scaling_factors->read_distributed(vec_md, this->col_part);
this->dist_mat->col_scale(col_scaling_factors);

I<I<value_type>> res_col_scale_local[] = {
{{8, 0}, {0, 0}}, {{0, 10}, {0, 0}}, {{0}}};
I<I<value_type>> res_col_scale_non_local[] = {
{{2, 0}, {6, 12}}, {{0, 0, 18}, {32, 35, 0}}, {{50, 9}}};
GKO_ASSERT_MTX_NEAR(gko::as<csr>(this->dist_mat->get_local_matrix()),
res_col_scale_local[rank], 0);
GKO_ASSERT_MTX_NEAR(gko::as<csr>(this->dist_mat->get_non_local_matrix()),
res_col_scale_non_local[rank], 0);
}


TYPED_TEST(Matrix, CanRowScale)
{
using value_type = typename TestFixture::value_type;
using index_type = typename TestFixture::global_index_type;
using csr = typename TestFixture::local_matrix_type;
using dist_vec_type = typename TestFixture::dist_vec_type;
auto vec_md = gko::matrix_data<value_type, index_type>{
I<I<value_type>>{{1}, {2}, {3}, {4}, {5}}};
auto rank = this->comm.rank();

auto row_scaling_factors = dist_vec_type::create(this->exec, this->comm);
row_scaling_factors->read_distributed(vec_md, this->row_part);
this->dist_mat->row_scale(row_scaling_factors);

I<I<value_type>> res_row_scale_local[] = {
{{2, 0}, {0, 0}}, {{0, 15}, {0, 0}}, {{0}}};
I<I<value_type>> res_row_scale_non_local[] = {
{{1, 0}, {6, 8}}, {{0, 0, 18}, {32, 28, 0}}, {{50, 45}}};
GKO_ASSERT_MTX_NEAR(gko::as<csr>(this->dist_mat->get_local_matrix()),
res_row_scale_local[rank], 0);
GKO_ASSERT_MTX_NEAR(gko::as<csr>(this->dist_mat->get_non_local_matrix()),
res_row_scale_non_local[rank], 0);
}


TYPED_TEST(Matrix, ColScaleThrowsOnWrongDimension)
{
using value_type = typename TestFixture::value_type;
using index_type = typename TestFixture::global_index_type;
using dist_vec_type = typename TestFixture::dist_vec_type;
using part_type = typename TestFixture::part_type;
auto vec_md = gko::matrix_data<value_type, index_type>{
I<I<value_type>>{{1}, {2}, {3}, {4}}};
auto rank = this->comm.rank();

auto col_part = part_type::build_from_mapping(
this->exec,
gko::array<gko::experimental::distributed::comm_index_type>(
this->exec,
I<gko::experimental::distributed::comm_index_type>{1, 2, 0, 0}),
3);
auto col_scaling_factors = dist_vec_type::create(this->exec, this->comm);
col_scaling_factors->read_distributed(vec_md, col_part);
ASSERT_THROW(this->dist_mat->col_scale(col_scaling_factors),
gko::DimensionMismatch);
}


TYPED_TEST(Matrix, RowScaleThrowsOnWrongDimension)
{
using value_type = typename TestFixture::value_type;
using index_type = typename TestFixture::global_index_type;
using dist_vec_type = typename TestFixture::dist_vec_type;
using part_type = typename TestFixture::part_type;
auto vec_md = gko::matrix_data<value_type, index_type>{
I<I<value_type>>{{1}, {2}, {3}, {4}}};
auto rank = this->comm.rank();

auto row_part = part_type::build_from_contiguous(
this->exec,
gko::array<index_type>(this->exec, I<index_type>{0, 2, 3, 4}));
auto row_scaling_factors = dist_vec_type::create(this->exec, this->comm);
row_scaling_factors->read_distributed(vec_md, row_part);
ASSERT_THROW(this->dist_mat->row_scale(row_scaling_factors),
gko::DimensionMismatch);
}


TYPED_TEST(Matrix, CanConvertToNextPrecision)
{
using T = typename TestFixture::value_type;
Expand Down

0 comments on commit 9b99cc5

Please sign in to comment.