fix(embedding_bag): code style
xlcjz committed Nov 29, 2023
1 parent d0644f6 commit d33024a
Showing 9 changed files with 96 additions and 55 deletions.
3 changes: 0 additions & 3 deletions paddle/phi/core/device_context.cc
@@ -18,8 +18,6 @@
#include "paddle/phi/backends/gpu/cuda/cuda_graph.h"
#endif

#include <glog/logging.h>
#include <gtest/gtest.h>
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/selected_rows.h"
@@ -156,7 +154,6 @@ struct DeviceContext::Impl {
ClearHolder(tensor);
}
} else {
VLOG(0) << "Segment Fault is about to come.";
if (tensor->initialized() && tensor->place() != place) {
ClearHolder(tensor);
}
21 changes: 12 additions & 9 deletions paddle/phi/infermeta/ternary.cc
@@ -1499,20 +1499,20 @@ void EmbeddingBagInferMeta(const MetaTensor& input,
const MetaTensor& per_sample_weight,
MetaTensor* out) {
const auto& table_dims = weight.dims();
auto table_dims_size = table_dims.size();
const auto& ids_dims = input.dims();
auto ids_dims_size = ids_dims.size();
const auto& weight_dims = per_sample_weight.dims();
int ids_rank = ids_dims.size();
VLOG(5) << "ids rank is " << ids_rank << std::endl;
PADDLE_ENFORCE_EQ(
ids_dims,
weight_dims,
phi::errors::InvalidArgument(
"ShapeError: The shapes of 'input' and 'per_sample_weight' must be the same."
"But received input's shape = [%s],"
"per_sample_weight's shape = [%s].",
ids_dims,
weight_dims));
phi::errors::InvalidArgument("ShapeError: The shapes of 'input' and "
"'per_sample_weight' must be the same."
"But received input's shape = [%s],"
"per_sample_weight's shape = [%s].",
ids_dims,
weight_dims));
PADDLE_ENFORCE_EQ(
table_dims.size(),
2,
@@ -1524,11 +1524,14 @@ void EmbeddingBagInferMeta(const MetaTensor& input,
table_dims));

auto output_dims =
phi::vectorize(phi::slice_ddim(ids_dims, 0, table_dims_size - 1));
phi::vectorize(phi::slice_ddim(ids_dims, 0, ids_dims_size - 1));
output_dims.push_back(table_dims[1]);
out->set_dims(phi::make_ddim(output_dims));
for (auto i : output_dims) {
VLOG(5) << i << " ";
}
out->set_dtype(weight.dtype());
out->share_lod(input);
VLOG(5) << "EmbeddingBagInferMeta End.";
}

} // namespace phi
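The fix above swaps table_dims_size for ids_dims_size when slicing the output shape, so the output keeps every dimension of 'input' except the last (the bag/sequence dimension) and appends the embedding width. A minimal sketch of that shape rule in plain Python, using hypothetical shapes:

input_shape = (10, 20)   # [num_bags, sequence_length], a 2-D ids tensor
weight_shape = (64, 3)   # [num_embeddings, embedding_dim]

# Mirrors: slice_ddim(ids_dims, 0, ids_dims_size - 1), then push_back(table_dims[1])
output_dims = list(input_shape[:-1]) + [weight_shape[1]]
assert output_dims == [10, 3]  # one reduced row of width embedding_dim per bag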
2 changes: 1 addition & 1 deletion paddle/phi/infermeta/ternary.h
@@ -64,8 +64,8 @@ void BoxCoderInferMeta(const MetaTensor& prior_box,
MetaConfig config = MetaConfig());

void EmbeddingBagInferMeta(const MetaTensor& input,
const MetaTensor& params,
const MetaTensor& weight,
const MetaTensor& per_sample_weight,
MetaTensor* out);

void DpsgdInferMeta(const MetaTensor& param,
9 changes: 6 additions & 3 deletions paddle/phi/kernels/cpu/embedding_bag_kernel.cc
@@ -35,9 +35,11 @@ struct EmbeddingBagCPUFunctor {
mode_(mode),
out_(out) {}

using EigenArrayMap = Eigen::Map<Eigen::Array<T, Eigen::Dynamic, Eigen::Dynamic>>;
using EigenArrayMap =
Eigen::Map<Eigen::Array<T, Eigen::Dynamic, Eigen::Dynamic>>;
using EigenVectorMap = Eigen::Map<Eigen::Vector<T, Eigen::Dynamic>>;
using ConstEigenVectorMap = Eigen::Map<const Eigen::Vector<T, Eigen::Dynamic>>;
using ConstEigenVectorMap =
Eigen::Map<const Eigen::Vector<T, Eigen::Dynamic>>;
using EigenIndex = Eigen::Index;

template <typename IdT>
@@ -61,7 +63,8 @@ struct EmbeddingBagCPUFunctor {
const ConstEigenVectorMap weight_slice(
&weight_d[input_d[bag * sequence_length + seq] * output_dim],
output_dim);
output_slice += weight_slice * per_sample_weight_d[bag * sequence_length + seq];
output_slice +=
weight_slice * per_sample_weight_d[bag * sequence_length + seq];
}
if (mode_ == "mean") {
output_slice /= static_cast<T>(sequence_length);
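For intuition, the loop above gathers one table row per (bag, seq) position, scales it by per_sample_weight, accumulates it into the bag's output slice, and divides by sequence_length when mode_ is "mean". A rough NumPy equivalent of that inner loop — a sketch assuming 2-D integer indices; the "max" branch is not visible in this hunk and is omitted:

import numpy as np

def embedding_bag_cpu_sketch(indices, table, per_sample_weight, mode="sum"):
    # indices: [num_bags, seq_len] int64; table: [num_embeddings, dim]
    num_bags, seq_len = indices.shape
    out = np.zeros((num_bags, table.shape[1]), dtype=table.dtype)
    for bag in range(num_bags):
        for seq in range(seq_len):
            row = table[indices[bag, seq]]                 # weight_slice
            out[bag] += row * per_sample_weight[bag, seq]  # weighted accumulate
        if mode == "mean":
            out[bag] /= seq_len  # output_slice /= sequence_length
    return out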
6 changes: 3 additions & 3 deletions paddle/phi/kernels/embedding_bag_kernel.h
@@ -18,14 +18,14 @@

namespace phi {

enum class CalMode { ksum, kmean, kmax};
enum class CalMode { ksum, kmean, kmax };

template <typename T, typename Context>
void EmbeddingBagCUDAKernel(const Context& ctx,
const DenseTensor& input,
const DenseTensor& weight,
const DenseTensor& per_sample_weight,
int64_t padding_idx,
const std::string &mode,
DenseTensor *out);
const std::string& mode,
DenseTensor* out);
} // namespace phi
19 changes: 10 additions & 9 deletions paddle/phi/kernels/gpu/embedding_bag_grad_kernel.cu
@@ -16,14 +16,14 @@
#include <thrust/device_ptr.h>
#include <thrust/fill.h>
#include <thrust/sort.h>
#include "paddle/phi/common/memory_utils.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/memory_utils.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"

namespace phi {

enum class CalMode_c { ksum, kmean, kmax};
enum class CalMode_c { ksum, kmean, kmax };

// kernel function: calculates the grad of the variable 'weight'
template <typename T, typename IdT>
@@ -56,7 +56,7 @@ __global__ void EmbeddingBagWeightsGrad(const int output_dim,
}
// assists in obtaining the map between the indices and the rows of params;
// see 'index_vec' in embedding_bag_grad_kernel.cc (in line 83)
//
//
template <typename IdT>
__global__ void PrepTempArraysKernel(const IdT *indices,
IdT *sortedIndices,
@@ -83,9 +83,10 @@ __global__ void EmbeddingBagParamsGrad(const int output_dim,
const int feature_idx = threadIdx.x + bag_idx * blockDim.x;
const int params_idx = __ldg(sortedIndices + sample_idx);
// following embedding_bag in tensorflow/addons: spin up a warp for each element
// of the indices array, having each warp check the previous element,
// of the indices array, having each warp check the previous element,
// if the same, return without operations. If not, the warp iterates forward
// and accumulates gradient. The operation is to avoid repeated reads and writes
// and accumulates gradient. The operation is to avoid repeated reads and
// writes
if (sample_idx > 0) {
const int prev_idx = __ldg(sortedIndices + sample_idx - 1);
if (prev_idx == params_idx) {
@@ -192,10 +193,10 @@ struct EmbeddingBagGradCUDAFunctor {

dim3 grids_2(total_blocks, 1, 1);

// the target of these operations is to avoid parallel writes to the same element of
// the grads. So 'PrepTempArraysKernel' is designed to pre-sorting a copy of the indices(sourtedIndices),
// and co-sorting a counter(sortedIndicesCounter).

// The goal of these operations is to avoid parallel writes to the same
// element of the grads, so 'PrepTempArraysKernel' is designed to
// pre-sort a copy of the indices (sortedIndices) and co-sort a
// counter (sortedIndicesCounter).

PrepTempArraysKernel<IdT>
<<<grids_2, kThreadsPerBlock, 0, dev_ctx_.stream()>>>(
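To make the sorted-accumulation idea concrete: after the indices are co-sorted with a position counter, each run of equal indices is walked exactly once, by the owner of the run's first element, so no two threads write the same row of the params gradient. A serial Python sketch of that pattern (an illustration, not the kernel's actual code):

import numpy as np

def params_grad_sketch(indices, grad_rows, num_rows):
    # indices: flat ids; grad_rows: per-sample gradient rows, shape [len(indices), dim]
    order = np.argsort(indices, kind="stable")  # plays the role of sortedIndicesCounter
    sorted_idx = indices[order]                 # plays the role of sortedIndices
    grad_params = np.zeros((num_rows, grad_rows.shape[1]), dtype=grad_rows.dtype)
    i = 0
    while i < len(sorted_idx):
        row, j = sorted_idx[i], i
        while j < len(sorted_idx) and sorted_idx[j] == row:
            grad_params[row] += grad_rows[order[j]]  # one pass accumulates the whole run
            j += 1
        i = j  # later elements of the run return early in the CUDA version
    return grad_params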
49 changes: 35 additions & 14 deletions paddle/phi/kernels/gpu/embedding_bag_kernel.cu
@@ -41,7 +41,7 @@ __global__ void EmbeddingBag(T *output,
int padding_idx_count = 0;
T sum = static_cast<T>(0);
T max_d = static_cast<T>(0);
for (int j = 0; j < S; j ++) {
for (int j = 0; j < S; j++) {
auto id = static_cast<int64_t>(ids[idy * S + j]);
const T *tab = table + id * D;
if (PaddingFlag && id == padding_idx) {
@@ -57,8 +57,10 @@
if (mode == CalMode::ksum) {
out[i] = sum;
} else if (mode == CalMode::kmean) {
if (padding_idx_count == S) out[i] = static_cast<T>(0);
else out[i] = sum / (S - padding_idx_count);
if (padding_idx_count == S)
out[i] = static_cast<T>(0);
else
out[i] = sum / (S - padding_idx_count);
} else {
out[i] = max_d;
}
@@ -73,7 +75,7 @@ struct EmbeddingBagCUDAFunctor {
const DenseTensor &input,
const DenseTensor &weight,
const DenseTensor &per_sample_weight,
const int64_t padding_idx,
int64_t padding_idx,
const std::string &mode,
DenseTensor *out)
: dev_ctx_(dev_ctx),
@@ -91,6 +93,11 @@ struct EmbeddingBagCUDAFunctor {
size_t K = input_.numel();
size_t S = input_.dims()[1];

printf("N;D;K;S %ld %ld %ld %ld\n", N, D, K, S);
const int gridx = 2 * dev_ctx_.GetSMCount();
dim3 blocks(256, 4);
dim3 grids(gridx, 1);

const T *weight_d = weight_.data<T>();
const IdT *ids_d = input_.data<IdT>();
const T *per_sample_weight_d = per_sample_weight_.data<T>();
@@ -99,20 +106,34 @@
auto stream = dev_ctx_.stream();
printf("After Alloc\n");

const int gridx = 2 * dev_ctx_.GetSMCount();
dim3 blocks(256, 4);
dim3 grids(gridx, 1);

CalMode mode_enum = CalMode::ksum;
if (mode_ == "mean") mode_enum = CalMode::kmean;
if (mode_ == "max") mode_enum = CalMode::kmax;

if (padding_idx_ == -1) {
EmbeddingBag<T, IdT, false><<<grids, blocks, 0, stream>>>(
output_d, weight_d, ids_d, per_sample_weight_d, N, K, D, S, padding_idx_, mode_enum);
EmbeddingBag<T, IdT, false>
<<<grids, blocks, 0, stream>>>(output_d,
weight_d,
ids_d,
per_sample_weight_d,
N,
K,
D,
S,
padding_idx_,
mode_enum);
} else {
EmbeddingBag<T, IdT, true><<<grids, blocks, 0, stream>>>(
output_d, weight_d, ids_d, per_sample_weight_d, N, K, D, S, padding_idx_, mode_enum);
EmbeddingBag<T, IdT, true>
<<<grids, blocks, 0, stream>>>(output_d,
weight_d,
ids_d,
per_sample_weight_d,
N,
K,
D,
S,
padding_idx_,
mode_enum);
}
}

Expand All @@ -121,8 +142,8 @@ struct EmbeddingBagCUDAFunctor {
const DenseTensor &input_;
const DenseTensor &weight_;
const DenseTensor &per_sample_weight_;
const std::string& mode_;
const int64_t padding_idx_;
int64_t padding_idx_;
const std::string &mode_;
DenseTensor *out_;
};

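The tail of the forward kernel shows how padding interacts with each mode: padded ids are counted and skipped, "sum" returns the accumulator, "mean" divides by the number of non-padded entries (an all-padding bag yields zero), and "max" returns the running maximum. A hedged NumPy sketch of one bag's reduction — the accumulation body is elided in this hunk, so the per-sample weighting below is an assumption:

import numpy as np

def reduce_bag_sketch(ids, table, sample_w, mode="sum", padding_idx=-1):
    # ids: [S] indices for one bag; table: [N, D]; sample_w: [S]
    acc = np.zeros(table.shape[1], dtype=table.dtype)
    max_d = np.zeros(table.shape[1], dtype=table.dtype)
    pad_count = 0
    for j, idx in enumerate(ids):
        if padding_idx >= 0 and idx == padding_idx:
            pad_count += 1  # padding rows contribute nothing
            continue
        row = table[idx] * sample_w[j]  # assumed weighting; elided in the hunk
        acc += row
        max_d = np.maximum(max_d, row)
    if mode == "sum":
        return acc
    if mode == "mean":
        S = len(ids)
        return np.zeros_like(acc) if pad_count == S else acc / (S - pad_count)
    return max_d  # mode == "max"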
24 changes: 17 additions & 7 deletions python/paddle/nn/functional/input.py
@@ -254,7 +254,15 @@ def embedding(x, weight, padding_idx=None, sparse=False, name=None):
return tmp


def embedding_bag(input, weight, per_sample_weight=None, padding_idx=None, mode="sum", sparse=False, name=None):
def embedding_bag(
input,
weight,
per_sample_weight=None,
padding_idx=None,
mode="sum",
sparse=False,
name=None,
):
"""
Used to calculate the sum, mean, or max of the specified bag in the embedding vectors indexed by :attr:`input`.
Each bag contains several row indexes of embeddings.
@@ -289,9 +297,9 @@ def embedding_bag(input, weight, per_sample_weight=None, padding_idx=None, mode=
such as :ref:`api_paddle_optimizer_adadelta_Adadelta` , :ref:`api_paddle_optimizer_adamax_Adamax` , :ref:`api_paddle_optimizer_lamb_Lamb`.
In these cases, sparse must be False. Default: False.
padding_idx(int|long|None, optional): padding_idx needs to be in the interval [-weight.shape[0], weight.shape[0]).
If :math:`padding\_idx < 0`, the :math:`padding\_idx` will automatically be converted
to :math:`weight.shape[0] + padding\_idx` . It will output all-zero padding data whenever lookup
encounters :math:`padding\_idx` in id. And the padding data will not be updated while training.
If :math:`padding_idx < 0`, the :math:`padding_idx` will automatically be converted
to :math:`weight.shape[0] + padding_idx` . It will output all-zero padding data whenever lookup
encounters :math:`padding_idx` in id. And the padding data will not be updated while training.
If set None, it makes no effect to output. Default: None.
mode(str): Specifies the way to reduce the bag. "sum" computes the weighted sum, taking per_sample_weight into consideration; "mean" computes the
average of the values in the bag; "max" computes the max value over each bag. Default: "sum"
@@ -308,7 +316,7 @@ def embedding_bag(input, weight, per_sample_weight=None, padding_idx=None, mode=
>>> input = np.random.randint(low=0, high=10, size = (2,6)).astype(np.int64)
>>> input = paddle.to_tensor(input, stop_gradient = False)
>>> per_sample_weight = np.random.random((2,6)).astype(np.float32)
>>> per_sample_weight = paddle.to_tensor(per_sample_weight, stop_gradient = False)
>>> per_sample_weight = paddle.to_tensor(per_sample_weight, stop_gradient = False)
>>> weight = np.random.random((10,3)).astype(np.float32)
>>> weight = paddle.to_tensor(weight, stop_gradient = False)
>>> sum = nn.functional.embedding_bag(input, weight, per_sample_weight, mode='sum')
@@ -330,7 +338,9 @@ def embedding_bag(input, weight, per_sample_weight=None, padding_idx=None, mode=
)

if in_dynamic_or_pir_mode():
return _C_ops.embedding_bag(input, weight, per_sample_weight, padding_idx, mode, sparse, name)
return _C_ops.embedding_bag(
input, weight, per_sample_weight, padding_idx, mode, sparse
)
else:
helper = LayerHelper('embedding_bag', **locals())
dtype = helper.input_dtype(input_param_name='weight')
@@ -360,7 +370,7 @@ def embedding_bag(input, weight, per_sample_weight=None, padding_idx=None, mode=
'is_distributed': is_distributed,
'remote_prefetch': remote_prefetch,
'padding_idx': padding_idx,
'mode': mode
'mode': mode,
},
)
return tmp
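One detail worth keeping in mind from the docstring above: a negative padding_idx is wrapped to weight.shape[0] + padding_idx, and None disables padding entirely. A small sketch of that normalization (a hypothetical helper; the actual branch lives in the elided part of embedding_bag):

def normalize_padding_idx(padding_idx, num_embeddings):
    # Maps padding_idx into [0, num_embeddings); None passes through.
    if padding_idx is None:
        return None
    if not -num_embeddings <= padding_idx < num_embeddings:
        raise ValueError("padding_idx must be in [-num_embeddings, num_embeddings)")
    return padding_idx + num_embeddings if padding_idx < 0 else padding_idx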
18 changes: 12 additions & 6 deletions test/legacy_test/test_embeddingbag_op.py
@@ -39,7 +39,9 @@ def manual_embeddingbag(input, params, weights=None, mode="sum"):
def get_input(rows=5, cols=3, num_embeddings=10):
a = np.random.choice(np.arange(num_embeddings), size=cols, replace=False)
for _ in range(rows - 1):
b = np.random.choice(np.arange(num_embeddings), size=cols, replace=False)
b = np.random.choice(
np.arange(num_embeddings), size=cols, replace=False
)
a = np.vstack((a, b))
return a

@@ -53,11 +55,15 @@ def setUp(self):
self.python_api = paddle.nn.functional.embedding_bag
weight = np.random.random((20, 64)).astype(self.dtype)
input = get_input(10, 20, weight.shape[0])
per_sample_weight = np.random.randint(low=0, high=10, size=input.shape).astype(
np.float64
)

self.inputs = {'input': input, 'weight': weight, 'per_sample_weight': per_sample_weight}
per_sample_weight = np.random.randint(
low=0, high=10, size=input.shape
).astype(np.float64)

self.inputs = {
'input': input,
'weight': weight,
'per_sample_weight': per_sample_weight,
}
np_out = manual_embeddingbag(input, weight, per_sample_weight)
self.outputs = {
'out': np_out.reshape((input.shape[0], weight.shape[1]))
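The reference helper manual_embeddingbag is visible here only by its signature; given how the test calls it, a plausible NumPy implementation would look like the following (an assumption for illustration, not the file's actual body):

import numpy as np

def manual_embeddingbag_sketch(input, params, weights=None, mode="sum"):
    # input: [rows, cols] indices; params: [num_embeddings, dim]
    if weights is None:
        weights = np.ones(input.shape, dtype=params.dtype)
    gathered = params[input] * weights[..., None]  # [rows, cols, dim]
    if mode == "sum":
        return gathered.sum(axis=1)
    if mode == "mean":
        return gathered.mean(axis=1)
    return gathered.max(axis=1)  # mode == "max"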
