Skip to content

Commit

Permalink
Remove workaround and call kernels directly
Browse files Browse the repository at this point in the history
  • Loading branch information
pratikvn committed Feb 15, 2022
1 parent 99ff25d commit 889a53a
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 140 deletions.
103 changes: 33 additions & 70 deletions omp/matrix/csr_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -744,62 +744,6 @@ GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(
GKO_DECLARE_CSR_CALC_NNZ_PER_ROW_IN_SPAN_KERNEL);


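// TODO: FIXME - temporary workaround: these helpers wrap the index_set
// kernels; remove them and call the kernels directly.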
namespace index_set {


/**
 * Maps an array of global indices to local indices of an index set.
 *
 * Allocates an output array on `exec` with as many entries as
 * `global_indices` and fills it by invoking the OpenMP
 * `index_set::global_to_local` kernel.
 *
 * @param exec  the executor used for the allocation and the kernel call
 * @param index_set  the index set providing size, subset bounds and superset
 *                   indices; it is asserted to have at least one subset
 * @param global_indices  the global indices to be mapped
 * @param is_sorted  forwarded unchanged to the kernel
 *
 * @return the array written by the kernel
 */
template <typename IndexType>
Array<IndexType> map_global_to_local(
    std::shared_ptr<const DefaultExecutor> exec,
    const IndexSet<IndexType>& index_set,
    const Array<IndexType>& global_indices, const bool is_sorted)
{
    Array<IndexType> mapped(exec, global_indices.get_num_elems());

    GKO_ASSERT(index_set.get_num_subsets() >= 1);
    const auto num_mapped = static_cast<IndexType>(mapped.get_num_elems());
    gko::kernels::omp::index_set::global_to_local(
        exec, index_set.get_size(), index_set.get_num_subsets(),
        index_set.get_subsets_begin(), index_set.get_subsets_end(),
        index_set.get_superset_indices(), num_mapped,
        global_indices.get_const_data(), mapped.get_data(), is_sorted);
    return mapped;
}


/**
 * Maps a single global index to its local index in the index set.
 *
 * Wraps `index` into a one-element array, maps it with
 * map_global_to_local() (with `is_sorted == true`) and copies the single
 * result back through the executor.
 *
 * @param exec  the executor to run on
 * @param index_set  the index set to query
 * @param index  the global index to be mapped
 *
 * @return the value the mapping produced for `index`
 */
template <typename IndexType>
IndexType get_local_index(std::shared_ptr<const DefaultExecutor> exec,
                          const IndexSet<IndexType>& index_set,
                          const IndexType index)
{
    const Array<IndexType> query(exec,
                                 std::initializer_list<IndexType>{index});
    Array<IndexType> answer(
        exec, index_set::map_global_to_local(exec, index_set, query, true));

    return exec->copy_val_to_host(answer.get_data());
}


/**
 * Checks whether a global index is part of the index set.
 *
 * @param exec  the executor to run on
 * @param index_set  the index set to query
 * @param input_index  the global index to look for
 *
 * @return false if `input_index` is not smaller than the size of the index
 *         set; otherwise whether the mapped local index differs from
 *         `invalid_index<IndexType>()`
 */
template <typename IndexType>
bool contains(std::shared_ptr<const DefaultExecutor> exec,
              const IndexSet<IndexType>& index_set, const IndexType input_index)
{
    // Indices at or beyond the index space size are rejected without a
    // kernel call.
    if (input_index >= index_set.get_size()) {
        return false;
    }
    const auto mapped_index =
        index_set::get_local_index(exec, index_set, input_index);
    return mapped_index != invalid_index<IndexType>();
}


}  // namespace index_set


template <typename ValueType, typename IndexType>
void calculate_nonzeros_per_row_in_index_set(
std::shared_ptr<const DefaultExecutor> exec,
Expand All @@ -811,14 +755,26 @@ void calculate_nonzeros_per_row_in_index_set(
auto num_row_subsets = row_index_set.get_num_subsets();
auto row_subset_begin = row_index_set.get_subsets_begin();
auto row_subset_end = row_index_set.get_subsets_end();
auto src_ptrs = source->get_const_row_ptrs();
for (size_type set = 0; set < num_row_subsets; ++set) {
for (size_type row = row_subset_begin[set]; row < row_subset_end[set];
++row) {
row_nnz->get_data()[res_row] = zero<IndexType>();
for (size_type nnz = source->get_const_row_ptrs()[row];
nnz < source->get_const_row_ptrs()[row + 1]; ++nnz) {
if (index_set::contains(exec, col_index_set,
source->get_const_col_idxs()[nnz])) {
Array<IndexType> l_idxs(
exec,
static_cast<size_type>(src_ptrs[row + 1] - src_ptrs[row]));
gko::kernels::omp::index_set::global_to_local(
exec, col_index_set.get_size(), col_index_set.get_num_subsets(),
col_index_set.get_subsets_begin(),
col_index_set.get_subsets_end(),
col_index_set.get_superset_indices(),
static_cast<IndexType>(l_idxs.get_num_elems()),
source->get_const_col_idxs() + src_ptrs[row], l_idxs.get_data(),
false);
for (size_type nnz = 0; nnz < (src_ptrs[row + 1] - src_ptrs[row]);
++nnz) {
auto l_idx = l_idxs.get_const_data()[nnz];
if (l_idx != invalid_index<IndexType>()) {
row_nnz->get_data()[res_row]++;
}
}
Expand Down Expand Up @@ -888,19 +844,26 @@ void compute_submatrix_from_index_set(
for (size_type set = 0; set < num_row_subsets; ++set) {
for (size_type row = row_subset_begin[set]; row < row_subset_end[set];
++row) {
auto local_map = std::vector<IndexType>(
src_row_ptrs[row + 1] - src_row_ptrs[row], 0);
for (size_type nnz = src_row_ptrs[row]; nnz < src_row_ptrs[row + 1];
++nnz) {
if (index_set::contains(exec, col_index_set,
src_col_idxs[nnz])) {
res_col_idxs[res_nnz] = index_set::get_local_index(
exec, col_index_set, src_col_idxs[nnz]);
res_values[res_nnz] = src_values[nnz];
Array<IndexType> l_idxs(
exec, static_cast<size_type>(src_row_ptrs[row + 1] -
src_row_ptrs[row]));
gko::kernels::omp::index_set::global_to_local(
exec, col_index_set.get_size(), col_index_set.get_num_subsets(),
col_index_set.get_subsets_begin(),
col_index_set.get_subsets_end(),
col_index_set.get_superset_indices(),
static_cast<IndexType>(l_idxs.get_num_elems()),
source->get_const_col_idxs() + src_row_ptrs[row],
l_idxs.get_data(), false);
for (size_type nnz = 0;
nnz < (src_row_ptrs[row + 1] - src_row_ptrs[row]); ++nnz) {
auto l_idx = l_idxs.get_const_data()[nnz];
if (l_idx != invalid_index<IndexType>()) {
res_col_idxs[res_nnz] = l_idx;
res_values[res_nnz] = src_values[nnz + src_row_ptrs[row]];
res_nnz++;
}
}
// res_nnz = res_row_ptrs[row_index_set.get_local_index(row)];
}
}
}
Expand Down
103 changes: 33 additions & 70 deletions reference/matrix/csr_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -625,62 +625,6 @@ GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(
GKO_DECLARE_CSR_CALC_NNZ_PER_ROW_IN_SPAN_KERNEL);


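// TODO: FIXME - temporary workaround: these helpers wrap the index_set
// kernels; remove them and call the kernels directly.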
namespace index_set {


/**
 * Maps an array of global indices to local indices of an index set.
 *
 * Allocates an output array on `exec` with as many entries as
 * `global_indices` and fills it by invoking the reference
 * `index_set::global_to_local` kernel.
 *
 * @param exec  the executor used for the allocation and the kernel call
 * @param index_set  the index set providing size, subset bounds and superset
 *                   indices; it is asserted to have at least one subset
 * @param global_indices  the global indices to be mapped
 * @param is_sorted  forwarded unchanged to the kernel
 *
 * @return the array written by the kernel
 */
template <typename IndexType>
Array<IndexType> map_global_to_local(
    std::shared_ptr<const DefaultExecutor> exec,
    const IndexSet<IndexType>& index_set,
    const Array<IndexType>& global_indices, const bool is_sorted)
{
    Array<IndexType> mapped(exec, global_indices.get_num_elems());

    GKO_ASSERT(index_set.get_num_subsets() >= 1);
    const auto num_mapped = static_cast<IndexType>(mapped.get_num_elems());
    gko::kernels::reference::index_set::global_to_local(
        exec, index_set.get_size(), index_set.get_num_subsets(),
        index_set.get_subsets_begin(), index_set.get_subsets_end(),
        index_set.get_superset_indices(), num_mapped,
        global_indices.get_const_data(), mapped.get_data(), is_sorted);
    return mapped;
}


/**
 * Maps a single global index to its local index in the index set.
 *
 * Wraps `index` into a one-element array, maps it with
 * map_global_to_local() (with `is_sorted == true`) and copies the single
 * result back through the executor.
 *
 * @param exec  the executor to run on
 * @param index_set  the index set to query
 * @param index  the global index to be mapped
 *
 * @return the value the mapping produced for `index`
 */
template <typename IndexType>
IndexType get_local_index(std::shared_ptr<const DefaultExecutor> exec,
                          const IndexSet<IndexType>& index_set,
                          const IndexType index)
{
    const Array<IndexType> query(exec,
                                 std::initializer_list<IndexType>{index});
    Array<IndexType> answer(
        exec, index_set::map_global_to_local(exec, index_set, query, true));

    return exec->copy_val_to_host(answer.get_data());
}


/**
 * Checks whether a global index is part of the index set.
 *
 * @param exec  the executor to run on
 * @param index_set  the index set to query
 * @param input_index  the global index to look for
 *
 * @return false if `input_index` is not smaller than the size of the index
 *         set; otherwise whether the mapped local index differs from
 *         `invalid_index<IndexType>()`
 */
template <typename IndexType>
bool contains(std::shared_ptr<const DefaultExecutor> exec,
              const IndexSet<IndexType>& index_set, const IndexType input_index)
{
    // Indices at or beyond the index space size are rejected without a
    // kernel call.
    if (input_index >= index_set.get_size()) {
        return false;
    }
    const auto mapped_index =
        index_set::get_local_index(exec, index_set, input_index);
    return mapped_index != invalid_index<IndexType>();
}


}  // namespace index_set


template <typename ValueType, typename IndexType>
void calculate_nonzeros_per_row_in_index_set(
std::shared_ptr<const DefaultExecutor> exec,
Expand All @@ -692,14 +636,26 @@ void calculate_nonzeros_per_row_in_index_set(
auto num_row_subsets = row_index_set.get_num_subsets();
auto row_subset_begin = row_index_set.get_subsets_begin();
auto row_subset_end = row_index_set.get_subsets_end();
auto src_ptrs = source->get_const_row_ptrs();
for (size_type set = 0; set < num_row_subsets; ++set) {
for (size_type row = row_subset_begin[set]; row < row_subset_end[set];
++row) {
row_nnz->get_data()[res_row] = zero<IndexType>();
for (size_type nnz = source->get_const_row_ptrs()[row];
nnz < source->get_const_row_ptrs()[row + 1]; ++nnz) {
if (index_set::contains(exec, col_index_set,
source->get_const_col_idxs()[nnz])) {
Array<IndexType> l_idxs(
exec,
static_cast<size_type>(src_ptrs[row + 1] - src_ptrs[row]));
gko::kernels::reference::index_set::global_to_local(
exec, col_index_set.get_size(), col_index_set.get_num_subsets(),
col_index_set.get_subsets_begin(),
col_index_set.get_subsets_end(),
col_index_set.get_superset_indices(),
static_cast<IndexType>(l_idxs.get_num_elems()),
source->get_const_col_idxs() + src_ptrs[row], l_idxs.get_data(),
false);
for (size_type nnz = 0; nnz < (src_ptrs[row + 1] - src_ptrs[row]);
++nnz) {
auto l_idx = l_idxs.get_const_data()[nnz];
if (l_idx != invalid_index<IndexType>()) {
row_nnz->get_data()[res_row]++;
}
}
Expand Down Expand Up @@ -770,19 +726,26 @@ void compute_submatrix_from_index_set(
for (size_type set = 0; set < num_row_subsets; ++set) {
for (size_type row = row_subset_begin[set]; row < row_subset_end[set];
++row) {
auto local_map = std::vector<IndexType>(
src_row_ptrs[row + 1] - src_row_ptrs[row], 0);
for (size_type nnz = src_row_ptrs[row]; nnz < src_row_ptrs[row + 1];
++nnz) {
if (index_set::contains(exec, col_index_set,
src_col_idxs[nnz])) {
res_col_idxs[res_nnz] = index_set::get_local_index(
exec, col_index_set, src_col_idxs[nnz]);
res_values[res_nnz] = src_values[nnz];
Array<IndexType> l_idxs(
exec, static_cast<size_type>(src_row_ptrs[row + 1] -
src_row_ptrs[row]));
gko::kernels::reference::index_set::global_to_local(
exec, col_index_set.get_size(), col_index_set.get_num_subsets(),
col_index_set.get_subsets_begin(),
col_index_set.get_subsets_end(),
col_index_set.get_superset_indices(),
static_cast<IndexType>(l_idxs.get_num_elems()),
source->get_const_col_idxs() + src_row_ptrs[row],
l_idxs.get_data(), false);
for (size_type nnz = 0;
nnz < (src_row_ptrs[row + 1] - src_row_ptrs[row]); ++nnz) {
auto l_idx = l_idxs.get_const_data()[nnz];
if (l_idx != invalid_index<IndexType>()) {
res_col_idxs[res_nnz] = l_idx;
res_values[res_nnz] = src_values[nnz + src_row_ptrs[row]];
res_nnz++;
}
}
// res_nnz = res_row_ptrs[row_index_set.get_local_index(row)];
}
}
}
Expand Down

0 comments on commit 889a53a

Please sign in to comment.