Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
support pure boolean elemwise/broadcast binary op
Browse files Browse the repository at this point in the history
  • Loading branch information
haojin2 committed Nov 6, 2019
1 parent 3c404a5 commit 4c0464d
Show file tree
Hide file tree
Showing 9 changed files with 241 additions and 52 deletions.
22 changes: 17 additions & 5 deletions src/operator/mshadow_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,18 @@ using std::is_integral;
} \
}

#define MXNET_BINARY_MATH_OP_NC_WITH_BOOL(name, expr) \
struct name : public mxnet_op::tunable { \
template<typename DType, \
typename std::enable_if<!std::is_same<DType, bool>::value, int>::type = 0> \
MSHADOW_XINLINE static DType Map(DType a, DType b) { \
return (expr); \
} \
MSHADOW_XINLINE static bool Map(bool a, bool b) { \
return (expr); \
} \
}

#define MXNET_BINARY_LOGIC_OP_NC(name, expr) \
struct name : public mxnet_op::tunable { \
template<typename DType> \
Expand Down Expand Up @@ -192,8 +204,6 @@ MXNET_BINARY_MATH_OP_NC(left, a);

MXNET_BINARY_MATH_OP_NC(right, b);

MXNET_BINARY_MATH_OP_NC(mul, a * b);

#ifndef _WIN32
struct mixed_plus {
template<typename DType,
Expand Down Expand Up @@ -288,11 +298,13 @@ struct mixed_mul {
};
#endif

MXNET_BINARY_MATH_OP_NC(div, a / b);
MXNET_BINARY_MATH_OP_NC_WITH_BOOL(mul, a * b);

MXNET_BINARY_MATH_OP_NC_WITH_BOOL(div, a / b);

MXNET_BINARY_MATH_OP_NC(plus, a + b);
MXNET_BINARY_MATH_OP_NC_WITH_BOOL(plus, a + b);

MXNET_BINARY_MATH_OP_NC(minus, a - b);
MXNET_BINARY_MATH_OP_NC_WITH_BOOL(minus, a - b);

MXNET_UNARY_MATH_OP(negation, -a);

Expand Down
29 changes: 13 additions & 16 deletions src/operator/numpy/np_elemwise_broadcast_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ bool NumpyBinaryMixedPrecisionType(const nnvm::NodeAttrs& attrs,
return true;
}

#ifdef _WIN32
#ifndef _WIN32
#define MXNET_OPERATOR_REGISTER_NP_BINARY_MIXED_PRECISION(name) \
NNVM_REGISTER_OP(name) \
.set_num_inputs(2) \
Expand All @@ -85,10 +85,6 @@ bool NumpyBinaryMixedPrecisionType(const nnvm::NodeAttrs& attrs,
[](const NodeAttrs& attrs){ \
return std::vector<std::pair<int, int> >{{0, 0}, {1, 0}}; \
}) \
.set_attr<FResourceRequest>("FResourceRequest", \
[](const NodeAttrs& attrs) { \
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace}; \
}) \
.add_argument("lhs", "NDArray-or-Symbol", "First input to the function") \
.add_argument("rhs", "NDArray-or-Symbol", "Second input to the function")
#else
Expand All @@ -106,6 +102,10 @@ bool NumpyBinaryMixedPrecisionType(const nnvm::NodeAttrs& attrs,
[](const NodeAttrs& attrs){ \
return std::vector<std::pair<int, int> >{{0, 0}, {1, 0}}; \
}) \
.set_attr<FResourceRequest>("FResourceRequest", \
[](const NodeAttrs& attrs) { \
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace}; \
}) \
.add_argument("lhs", "NDArray-or-Symbol", "First input to the function") \
.add_argument("rhs", "NDArray-or-Symbol", "Second input to the function")
#endif
Expand All @@ -114,41 +114,38 @@ MXNET_OPERATOR_REGISTER_NP_BINARY_MIXED_PRECISION(_npi_add)
#ifndef _WIN32
.set_attr<FCompute>(
"FCompute<cpu>",
MixedBinaryBroadcastCompute<cpu, op::mshadow_op::plus, op::mshadow_op::mixed_plus,
op::mshadow_op::mixed_plus>)
NumpyBinaryBroadcastComputeWithBool<cpu, op::mshadow_op::plus, op::mshadow_op::mixed_plus,
op::mshadow_op::mixed_plus>)
#else
.set_attr<FCompute>(
"FCompute<cpu>",
MixedBinaryBroadcastCompute<cpu, op::mshadow_op::plus, op::mshadow_op::plus,
op::mshadow_op::plus>)
NumpyBinaryBroadcastComputeWithBool<cpu, op::mshadow_op::plus>)
#endif
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_broadcast_add"});

MXNET_OPERATOR_REGISTER_NP_BINARY_MIXED_PRECISION(_npi_subtract)
#ifndef _WIN32
.set_attr<FCompute>(
"FCompute<cpu>",
MixedBinaryBroadcastCompute<cpu, op::mshadow_op::minus, op::mshadow_op::mixed_minus,
NumpyBinaryBroadcastCompute<cpu, op::mshadow_op::minus, op::mshadow_op::mixed_minus,
op::mshadow_op::mixed_rminus>)
#else
.set_attr<FCompute>(
"FCompute<cpu>",
MixedBinaryBroadcastCompute<cpu, op::mshadow_op::minus, op::mshadow_op::minus,
op::mshadow_op::minus>)
NumpyBinaryBroadcastCompute<cpu, op::mshadow_op::minus>)
#endif
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_broadcast_sub"});

MXNET_OPERATOR_REGISTER_NP_BINARY_MIXED_PRECISION(_npi_multiply)
#ifndef _WIN32
.set_attr<FCompute>(
"FCompute<cpu>",
MixedBinaryBroadcastCompute<cpu, op::mshadow_op::mul, op::mshadow_op::mixed_mul,
op::mshadow_op::mixed_mul>)
NumpyBinaryBroadcastComputeWithBool<cpu, op::mshadow_op::mul, op::mshadow_op::mixed_mul,
op::mshadow_op::mixed_mul>)
#else
.set_attr<FCompute>(
"FCompute<cpu>",
MixedBinaryBroadcastCompute<cpu, op::mshadow_op::mul, op::mshadow_op::mul,
op::mshadow_op::mul>)
NumpyBinaryBroadcastComputeWithBool<cpu, op::mshadow_op::mul>)
#endif
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_broadcast_mul"});

Expand Down
19 changes: 8 additions & 11 deletions src/operator/numpy/np_elemwise_broadcast_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -32,39 +32,36 @@ NNVM_REGISTER_OP(_npi_add)
#ifndef _WIN32
.set_attr<FCompute>(
"FCompute<gpu>",
MixedBinaryBroadcastCompute<gpu, op::mshadow_op::plus, op::mshadow_op::mixed_plus,
op::mshadow_op::mixed_plus>);
NumpyBinaryBroadcastComputeWithBool<gpu, op::mshadow_op::plus, op::mshadow_op::mixed_plus,
op::mshadow_op::mixed_plus>);
#else
.set_attr<FCompute>(
"FCompute<gpu>",
MixedBinaryBroadcastCompute<gpu, op::mshadow_op::plus, op::mshadow_op::plus,
op::mshadow_op::plus>);
NumpyBinaryBroadcastComputeWithBool<gpu, op::mshadow_op::plus>);
#endif

NNVM_REGISTER_OP(_npi_subtract)
#ifndef _WIN32
.set_attr<FCompute>(
"FCompute<gpu>",
MixedBinaryBroadcastCompute<gpu, op::mshadow_op::minus, op::mshadow_op::mixed_minus,
NumpyBinaryBroadcastCompute<gpu, op::mshadow_op::minus, op::mshadow_op::mixed_minus,
op::mshadow_op::mixed_rminus>);
#else
.set_attr<FCompute>(
"FCompute<gpu>",
MixedBinaryBroadcastCompute<gpu, op::mshadow_op::minus, op::mshadow_op::minus,
op::mshadow_op::minus>);
NumpyBinaryBroadcastCompute<gpu, op::mshadow_op::minus>);
#endif

NNVM_REGISTER_OP(_npi_multiply)
#ifndef _WIN32
.set_attr<FCompute>(
"FCompute<gpu>",
MixedBinaryBroadcastCompute<gpu, op::mshadow_op::mul, op::mshadow_op::mixed_mul,
op::mshadow_op::mixed_mul>);
NumpyBinaryBroadcastComputeWithBool<gpu, op::mshadow_op::mul, op::mshadow_op::mixed_mul,
op::mshadow_op::mixed_mul>);
#else
.set_attr<FCompute>(
"FCompute<gpu>",
MixedBinaryBroadcastCompute<gpu, op::mshadow_op::mul, op::mshadow_op::mul,
op::mshadow_op::mul>);
NumpyBinaryBroadcastComputeWithBool<gpu, op::mshadow_op::mul>);
#endif

NNVM_REGISTER_OP(_npi_mod)
Expand Down
85 changes: 74 additions & 11 deletions src/operator/numpy/np_elemwise_broadcast_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@
namespace mxnet {
namespace op {

inline void PrintErrorMessage(const std::string& name, const int dtype1, const int dtype2) {
LOG(FATAL) << "Operator " << name << " does not support combination of "
inline void PrintErrorMessage(const std::string& op_name, const int dtype1, const int dtype2) {
LOG(FATAL) << "Operator " << op_name << " does not support combination of "
<< common::dtype_string(dtype1) << " with " << common::dtype_string(dtype2)
<< " yet...";
}
Expand Down Expand Up @@ -218,7 +218,11 @@ void MixedAllRealBinaryBroadcastCompute(const std::string& op_name,
}
#endif

#ifndef _WIN32
template<typename xpu, typename OP, typename LOP, typename ROP>
#else
template<typename xpu, typename OP>
#endif
void MixedBinaryBroadcastCompute(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
Expand All @@ -233,13 +237,6 @@ void MixedBinaryBroadcastCompute(const nnvm::NodeAttrs& attrs,
const TBlob& rhs = inputs[1];
const TBlob& out = outputs[0];

if ((out.shape_.Size() == 0U) || (req[0] == kNullOp)) return;

if (lhs.type_flag_ == rhs.type_flag_) {
BinaryBroadcastCompute<xpu, OP>(attrs, ctx, inputs, req, outputs);
return;
}

#ifndef _WIN32
mxnet::TShape new_lshape, new_rshape, new_oshape;
int ndim = BinaryBroadcastShapeCompact(lhs.shape_, rhs.shape_, out.shape_,
Expand Down Expand Up @@ -299,7 +296,7 @@ void MixedBinaryBroadcastCompute(const nnvm::NodeAttrs& attrs,
temp_tblob = TBlob(temp_tensor);
});
CastCompute<xpu>(attrs, ctx, {rhs}, {kWriteTo}, {temp_tblob});
BinaryBroadcastCompute<xpu, OP>(
BinaryBroadcastCompute<xpu, OP, allow_bool>(
attrs, ctx, {lhs, temp_tblob.reshape(rhs.shape_)}, req, outputs);
} else {
MSHADOW_REAL_TYPE_SWITCH(rhs.type_flag_, RType, {
Expand All @@ -308,7 +305,7 @@ void MixedBinaryBroadcastCompute(const nnvm::NodeAttrs& attrs,
temp_tblob = TBlob(temp_tensor);
});
CastCompute<xpu>(attrs, ctx, {lhs}, {kWriteTo}, {temp_tblob});
BinaryBroadcastCompute<xpu, OP>(
BinaryBroadcastCompute<xpu, OP, allow_bool>(
attrs, ctx, {temp_tblob.reshape(lhs.shape_), rhs}, req, outputs);
}
} else {
Expand All @@ -317,6 +314,72 @@ void MixedBinaryBroadcastCompute(const nnvm::NodeAttrs& attrs,
#endif
}

#ifndef _WIN32
template<typename xpu, typename OP, typename LOP, typename ROP>
#else
template<typename xpu, typename OP>
#endif
void NumpyBinaryBroadcastCompute(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
using namespace mshadow;
using namespace mxnet_op;
CHECK_EQ(inputs.size(), 2U);
CHECK_EQ(outputs.size(), 1U);

const TBlob& lhs = inputs[0];
const TBlob& rhs = inputs[1];
const TBlob& out = outputs[0];

if ((out.shape_.Size() == 0U) || (req[0] == kNullOp)) return;

if (lhs.type_flag_ == rhs.type_flag_) {
BinaryBroadcastCompute<xpu, OP>(attrs, ctx, inputs, req, outputs);
return;
}

#ifndef _WIN32
MixedBinaryBroadcastCompute<xpu, OP, LOP, ROP>(attrs, ctx, inputs, req, outputs);
#else
MixedBinaryBroadcastCompute<xpu, OP>(attrs, ctx, inputs, req, outputs);
#endif
}

#ifndef _WIN32
template<typename xpu, typename OP, typename LOP, typename ROP>
#else
template<typename xpu, typename OP>
#endif
void NumpyBinaryBroadcastComputeWithBool(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
using namespace mshadow;
using namespace mxnet_op;
CHECK_EQ(inputs.size(), 2U);
CHECK_EQ(outputs.size(), 1U);

const TBlob& lhs = inputs[0];
const TBlob& rhs = inputs[1];
const TBlob& out = outputs[0];

if ((out.shape_.Size() == 0U) || (req[0] == kNullOp)) return;

if (lhs.type_flag_ == rhs.type_flag_) {
BinaryBroadcastComputeWithBool<xpu, OP>(attrs, ctx, inputs, req, outputs);
return;
}

#ifndef _WIN32
MixedBinaryBroadcastCompute<xpu, OP, LOP, ROP>(attrs, ctx, inputs, req, outputs);
#else
MixedBinaryBroadcastCompute<xpu, OP>(attrs, ctx, inputs, req, outputs);
#endif
}

template<typename xpu, typename LOP, typename ROP>
void MixedBinaryBackwardUseIn(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
Expand Down
12 changes: 8 additions & 4 deletions src/operator/operator_tune-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,10 @@ class OperatorTune : public OperatorTuneByType<DType> {
TuneAll();
}

~OperatorTune() {
delete[] data_set_;
}

/*!
* \brief Initialize the OperatorTune object
* \return Whether the OperatorTune object was successfully initialized
Expand All @@ -124,7 +128,7 @@ class OperatorTune : public OperatorTuneByType<DType> {
if (!initialized_) {
initialized_ = true;
// Generate some random data for calling the operator kernels
data_set_.reserve(0x100);
data_set_ = reinterpret_cast<DType*>(new char[0x100 * sizeof(DType)]);
std::random_device rd;
std::mt19937 gen(rd());
if (!std::is_integral<DType>::value) {
Expand All @@ -136,7 +140,7 @@ class OperatorTune : public OperatorTuneByType<DType> {
--n;
continue;
}
data_set_.emplace_back(val);
data_set_[n] = val;
}
} else {
std::uniform_int_distribution<> dis(-128, 127);
Expand All @@ -147,7 +151,7 @@ class OperatorTune : public OperatorTuneByType<DType> {
--n;
continue;
}
data_set_.emplace_back(val);
data_set_[n] = val;
}
}
// Use this environment variable to generate new tuning statistics
Expand Down Expand Up @@ -517,7 +521,7 @@ class OperatorTune : public OperatorTuneByType<DType> {
/*! \brief Number of passes to obtain an average */
static constexpr duration_t OUTSIDE_COUNT = (1 << OUTSIDE_COUNT_SHIFT);
/*! \brief Random data for timing operator calls */
static std::vector<DType> data_set_;
static DType* data_set_;
/*! \brief Operators tuned */
static std::unordered_set<std::string> operator_names_;
/*! \brief Arbitary object to modify in OMP loop */
Expand Down
10 changes: 5 additions & 5 deletions src/operator/operator_tune.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ double OperatorTuneBase::tuning_weight_scale_ = 0.0;
*/
#define IMPLEMENT_OPERATOR_TUNE_STATICS_FOR_TYPE(__typ$) \
template<> bool OperatorTune<__typ$>::initialized_ = false; \
template<> std::vector<__typ$> OperatorTune<__typ$>::data_set_ = {}; \
template<> __typ$* OperatorTune<__typ$>::data_set_ = nullptr; \
template<> volatile tune::TuningMode OperatorTuneByType<__typ$>::tuning_mode_ = tune::kAuto; \
template<> volatile int OperatorTune<__typ$>::volatile_int_ = 9; /* arbitrary number */ \
template<> std::unordered_set<std::string> OperatorTune<__typ$>::operator_names_({}); \
Expand Down Expand Up @@ -314,10 +314,10 @@ IMPLEMENT_UNARY_WORKLOAD_FWD_WITH_BOOL(mxnet::op::mshadow_op::np_logical_not);
IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::nt); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::clip); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::clip); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::plus); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::minus); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::mul); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::div); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_FWD_WITH_BOOL(mxnet::op::mshadow_op::plus); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_FWD_WITH_BOOL(mxnet::op::mshadow_op::minus); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_FWD_WITH_BOOL(mxnet::op::mshadow_op::mul); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_FWD_WITH_BOOL(mxnet::op::mshadow_op::div); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::true_divide); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::minus_sign); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::rminus); // NOLINT()
Expand Down
Loading

0 comments on commit 4c0464d

Please sign in to comment.