Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Decouple dtype from shape for Random multinomial #15980

Merged
merged 4 commits into from
Aug 25, 2019
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 1 addition & 6 deletions src/operator/random/sample_multinomial_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,6 @@ inline bool SampleMultinomialOpShape(const nnvm::NodeAttrs& attrs,
const mxnet::TShape& ishape = (*in_attrs)[0];
if (!ndim_is_known(ishape)) return false;

MSHADOW_TYPE_SWITCH(param.dtype, DType, {
CHECK_LE(ishape[ishape.ndim() - 1], mxnet::common::MaxIntegerValue<DType>())
<< "'dtype' does not have a sufficient precision to represent the indices of the input array.";
});

if (ishape.ndim() == 1) {
if (param.shape.ndim() > 0) {
SHAPE_ASSIGN_CHECK(*out_attrs, 0, param.shape);
Expand Down Expand Up @@ -121,7 +116,7 @@ inline bool SampleMultinomialOpType(const nnvm::NodeAttrs& attrs,

struct SampleMultinomialKernel {
template<typename DType, typename IType>
MSHADOW_XINLINE static void Map(int i, index_t K, index_t M,
MSHADOW_XINLINE static void Map(index_t i, index_t K, index_t M,
DType* dist, float* uniform, float* cum_table,
IType* out, DType* prob) {
double acc = 0.0;
Expand Down