diff --git a/src/operator/mshadow_op.h b/src/operator/mshadow_op.h
index b8db165675a0..d563f2516341 100644
--- a/src/operator/mshadow_op.h
+++ b/src/operator/mshadow_op.h
@@ -97,6 +97,18 @@ using std::is_integral;
     } \
   }
 
+#define MXNET_BINARY_MATH_OP_NC_WITH_BOOL(name, expr) \
+  struct name : public mxnet_op::tunable { \
+    template<typename DType, \
+             typename std::enable_if<!std::is_same<DType, bool>::value, int>::type = 0> \
+    MSHADOW_XINLINE static DType Map(DType a, DType b) { \
+      return (expr); \
+    } \
+    MSHADOW_XINLINE static bool Map(bool a, bool b) { \
+      return (expr); \
+    } \
+  }
+
 #define MXNET_BINARY_LOGIC_OP_NC(name, expr) \
   struct name : public mxnet_op::tunable { \
     template<typename DType> \
@@ -192,8 +204,6 @@ MXNET_BINARY_MATH_OP_NC(left, a);
 
 MXNET_BINARY_MATH_OP_NC(right, b);
 
-MXNET_BINARY_MATH_OP_NC(mul, a * b);
-
 #ifndef _WIN32
 struct mixed_plus {
   template<typename DType,
diff --git a/src/operator/numpy/np_elemwise_broadcast_op.cc b/src/operator/numpy/np_elemwise_broadcast_op.cc
--- a/src/operator/numpy/np_elemwise_broadcast_op.cc
+++ b/src/operator/numpy/np_elemwise_broadcast_op.cc
       return std::vector<std::pair<int, int> >{{0, 0}, {1, 0}}; \
     }) \
-  .set_attr<FResourceRequest>("FResourceRequest", \
-    [](const NodeAttrs& attrs) { \
-      return std::vector<ResourceRequest>{ResourceRequest::kTempSpace}; \
-    }) \
   .add_argument("lhs", "NDArray-or-Symbol", "First input to the function") \
   .add_argument("rhs", "NDArray-or-Symbol", "Second input to the function")
 #else
@@ -106,6 +102,10 @@ bool NumpyBinaryMixedPrecisionType(const nnvm::NodeAttrs& attrs,
     [](const NodeAttrs& attrs){ \
       return std::vector<std::pair<int, int> >{{0, 0}, {1, 0}}; \
     }) \
+  .set_attr<FResourceRequest>("FResourceRequest", \
+    [](const NodeAttrs& attrs) { \
+      return std::vector<ResourceRequest>{ResourceRequest::kTempSpace}; \
+    }) \
   .add_argument("lhs", "NDArray-or-Symbol", "First input to the function") \
   .add_argument("rhs", "NDArray-or-Symbol", "Second input to the function")
 #endif
@@ -114,13 +114,12 @@ MXNET_OPERATOR_REGISTER_NP_BINARY_MIXED_PRECISION(_npi_add)
 #ifndef _WIN32
 .set_attr<FCompute>(
   "FCompute",
-  MixedBinaryBroadcastCompute<cpu, op::mshadow_op::plus, op::mshadow_op::mixed_plus,
-                              op::mshadow_op::mixed_plus>)
+  NumpyBinaryBroadcastComputeWithBool<cpu, op::mshadow_op::plus, op::mshadow_op::mixed_plus,
+                                      op::mshadow_op::mixed_plus>)
 #else
 .set_attr<FCompute>(
   "FCompute",
-  MixedBinaryBroadcastCompute<cpu, op::mshadow_op::plus, op::mshadow_op::mixed_plus,
-                              op::mshadow_op::mixed_plus>)
+  NumpyBinaryBroadcastComputeWithBool<cpu, op::mshadow_op::plus>)
 #endif
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_broadcast_add"});
 
@@ -128,13 +127,12 @@ MXNET_OPERATOR_REGISTER_NP_BINARY_MIXED_PRECISION(_npi_subtract)
 #ifndef _WIN32
 .set_attr<FCompute>(
   "FCompute",
-  MixedBinaryBroadcastCompute<cpu, op::mshadow_op::minus, op::mshadow_op::mixed_minus,
-                              op::mshadow_op::mixed_rminus>)
+  NumpyBinaryBroadcastCompute<cpu, op::mshadow_op::minus, op::mshadow_op::mixed_minus,
+                              op::mshadow_op::mixed_rminus>)
 #else
 .set_attr<FCompute>(
   "FCompute",
-  MixedBinaryBroadcastCompute<cpu, op::mshadow_op::minus, op::mshadow_op::mixed_minus,
-                              op::mshadow_op::mixed_rminus>)
+  NumpyBinaryBroadcastCompute<cpu, op::mshadow_op::minus>)
 #endif
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_broadcast_sub"});
 
@@ -142,13 +140,12 @@ MXNET_OPERATOR_REGISTER_NP_BINARY_MIXED_PRECISION(_npi_multiply)
 #ifndef _WIN32
 .set_attr<FCompute>(
   "FCompute",
-  MixedBinaryBroadcastCompute<cpu, op::mshadow_op::mul, op::mshadow_op::mixed_mul,
-                              op::mshadow_op::mixed_mul>)
+  NumpyBinaryBroadcastComputeWithBool<cpu, op::mshadow_op::mul, op::mshadow_op::mixed_mul,
+                                      op::mshadow_op::mixed_mul>)
 #else
 .set_attr<FCompute>(
   "FCompute",
-  MixedBinaryBroadcastCompute<cpu, op::mshadow_op::mul, op::mshadow_op::mixed_mul,
-                              op::mshadow_op::mixed_mul>)
+  NumpyBinaryBroadcastComputeWithBool<cpu, op::mshadow_op::mul>)
 #endif
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_broadcast_mul"});
diff --git a/src/operator/numpy/np_elemwise_broadcast_op.cu b/src/operator/numpy/np_elemwise_broadcast_op.cu
index 153ffd0048dd..a0a277df211f 100644
--- a/src/operator/numpy/np_elemwise_broadcast_op.cu
+++ b/src/operator/numpy/np_elemwise_broadcast_op.cu
@@ -32,39 +32,36 @@ NNVM_REGISTER_OP(_npi_add)
 #ifndef _WIN32
 .set_attr<FCompute>(
   "FCompute",
-  MixedBinaryBroadcastCompute<gpu, op::mshadow_op::plus, op::mshadow_op::mixed_plus,
-                              op::mshadow_op::mixed_plus>);
+  NumpyBinaryBroadcastComputeWithBool<gpu, op::mshadow_op::plus, op::mshadow_op::mixed_plus,
+                                      op::mshadow_op::mixed_plus>);
 #else
 .set_attr<FCompute>(
   "FCompute",
-  MixedBinaryBroadcastCompute<gpu, op::mshadow_op::plus, op::mshadow_op::mixed_plus,
-                              op::mshadow_op::mixed_plus>);
+  NumpyBinaryBroadcastComputeWithBool<gpu, op::mshadow_op::plus>);
 #endif
 
 NNVM_REGISTER_OP(_npi_subtract)
 #ifndef _WIN32
 .set_attr<FCompute>(
   "FCompute",
-  MixedBinaryBroadcastCompute<gpu, op::mshadow_op::minus, op::mshadow_op::mixed_minus,
-                              op::mshadow_op::mixed_rminus>);
+  NumpyBinaryBroadcastCompute<gpu, op::mshadow_op::minus, op::mshadow_op::mixed_minus,
+                              op::mshadow_op::mixed_rminus>);
 #else
 .set_attr<FCompute>(
   "FCompute",
-  MixedBinaryBroadcastCompute<gpu, op::mshadow_op::minus, op::mshadow_op::mixed_minus,
-                              op::mshadow_op::mixed_rminus>);
+  NumpyBinaryBroadcastCompute<gpu, op::mshadow_op::minus>);
 #endif
 
 NNVM_REGISTER_OP(_npi_multiply)
 #ifndef _WIN32
 .set_attr<FCompute>(
   "FCompute",
-  MixedBinaryBroadcastCompute<gpu, op::mshadow_op::mul, op::mshadow_op::mixed_mul,
-                              op::mshadow_op::mixed_mul>);
+  NumpyBinaryBroadcastComputeWithBool<gpu, op::mshadow_op::mul, op::mshadow_op::mixed_mul,
+                                      op::mshadow_op::mixed_mul>);
 #else
 .set_attr<FCompute>(
   "FCompute",
-  MixedBinaryBroadcastCompute<gpu, op::mshadow_op::mul, op::mshadow_op::mixed_mul,
-                              op::mshadow_op::mixed_mul>);
+  NumpyBinaryBroadcastComputeWithBool<gpu, op::mshadow_op::mul>);
 #endif
 
 NNVM_REGISTER_OP(_npi_mod)
diff --git a/src/operator/numpy/np_elemwise_broadcast_op.h b/src/operator/numpy/np_elemwise_broadcast_op.h
index 1d36c6ff881e..caa733a8304d 100644
--- a/src/operator/numpy/np_elemwise_broadcast_op.h
+++ b/src/operator/numpy/np_elemwise_broadcast_op.h
@@ -34,8 +34,8 @@ namespace mxnet {
 namespace op {
 
-inline void PrintErrorMessage(const std::string& name, const int dtype1, const int dtype2) {
-  LOG(FATAL) << "Operator " << name << " does not support combination of "
+inline void PrintErrorMessage(const std::string& op_name, const int dtype1, const int dtype2) {
+  LOG(FATAL) << "Operator " << op_name << " does not support combination of "
              << common::dtype_string(dtype1) << " with " << common::dtype_string(dtype2)
              << " yet...";
 }
@@ -218,7 +218,11 @@ void MixedAllRealBinaryBroadcastCompute(const std::string& op_name,
 }
 #endif
 
+#ifndef _WIN32
 template<typename xpu, typename OP, typename LOP, typename ROP>
+#else
+template<typename xpu, typename OP>
+#endif
 void MixedBinaryBroadcastCompute(const nnvm::NodeAttrs& attrs,
                                  const OpContext& ctx,
                                  const std::vector<TBlob>& inputs,
@@ -233,13 +237,6 @@ void MixedBinaryBroadcastCompute(const nnvm::NodeAttrs& attrs,
   const TBlob& rhs = inputs[1];
   const TBlob& out = outputs[0];
 
-  if ((out.shape_.Size() == 0U) || (req[0] == kNullOp)) return;
-
-  if (lhs.type_flag_ == rhs.type_flag_) {
-    BinaryBroadcastCompute<xpu, OP>(attrs, ctx, inputs, req, outputs);
-    return;
-  }
-
 #ifndef _WIN32
   mxnet::TShape new_lshape, new_rshape, new_oshape;
   int ndim = BinaryBroadcastShapeCompact(lhs.shape_, rhs.shape_, out.shape_,
@@ -299,7 +296,7 @@ void MixedBinaryBroadcastCompute(const nnvm::NodeAttrs& attrs,
         temp_tblob = TBlob(temp_tensor);
       });
       CastCompute<xpu>(attrs, ctx, {rhs}, {kWriteTo}, {temp_tblob});
-      BinaryBroadcastCompute<xpu, LOP>(
+      BinaryBroadcastCompute<xpu, OP>(
         attrs, ctx, {lhs, temp_tblob.reshape(rhs.shape_)}, req, outputs);
     } else {
       MSHADOW_REAL_TYPE_SWITCH(rhs.type_flag_, RType, {
@@ -308,7 +305,7 @@ void MixedBinaryBroadcastCompute(const nnvm::NodeAttrs& attrs,
         temp_tblob = TBlob(temp_tensor);
       });
       CastCompute<xpu>(attrs, ctx, {lhs}, {kWriteTo}, {temp_tblob});
-      BinaryBroadcastCompute<xpu, ROP>(
+      BinaryBroadcastCompute<xpu, OP>(
         attrs, ctx, {temp_tblob.reshape(lhs.shape_), rhs}, req, outputs);
     }
   } else {
@@ -317,6 +314,72 @@ void MixedBinaryBroadcastCompute(const nnvm::NodeAttrs& attrs,
 #endif
 }
 
+#ifndef _WIN32
+template<typename xpu, typename OP, typename LOP, typename ROP>
+#else
+template<typename xpu, typename OP>
+#endif
+void NumpyBinaryBroadcastCompute(const nnvm::NodeAttrs& attrs,
+                                 const OpContext& ctx,
+                                 const std::vector<TBlob>& inputs,
+                                 const std::vector<OpReqType>& req,
+                                 const std::vector<TBlob>& outputs) {
+  using namespace mshadow;
+  using namespace mxnet_op;
+  CHECK_EQ(inputs.size(), 2U);
+  CHECK_EQ(outputs.size(), 1U);
+
+  const TBlob& lhs = inputs[0];
+  const TBlob& rhs = inputs[1];
+  const TBlob& out = outputs[0];
+
+  if ((out.shape_.Size() == 0U) || (req[0] == kNullOp)) return;
+
+  if (lhs.type_flag_ == rhs.type_flag_) {
+    BinaryBroadcastCompute<xpu, OP>(attrs, ctx, inputs, req, outputs);
+    return;
+  }
+
+#ifndef _WIN32
+  MixedBinaryBroadcastCompute<xpu, OP, LOP, ROP>(attrs, ctx, inputs, req, outputs);
+#else
+  MixedBinaryBroadcastCompute<xpu, OP>(attrs, ctx, inputs, req, outputs);
+#endif
+}
+
+#ifndef _WIN32
+template<typename xpu, typename OP, typename LOP, typename ROP>
+#else
+template<typename xpu, typename OP>
+#endif
+void NumpyBinaryBroadcastComputeWithBool(const nnvm::NodeAttrs& attrs,
+                                         const OpContext& ctx,
+                                         const std::vector<TBlob>& inputs,
+                                         const std::vector<OpReqType>& req,
+                                         const std::vector<TBlob>& outputs) {
+  using namespace mshadow;
+  using namespace mxnet_op;
+  CHECK_EQ(inputs.size(), 2U);
+  CHECK_EQ(outputs.size(), 1U);
+
+  const TBlob& lhs = inputs[0];
+  const TBlob& rhs = inputs[1];
+  const TBlob& out = outputs[0];
+
+  if ((out.shape_.Size() == 0U) || (req[0] == kNullOp)) return;
+
+  if (lhs.type_flag_ == rhs.type_flag_) {
+    BinaryBroadcastComputeWithBool<xpu, OP>(attrs, ctx, inputs, req, outputs);
+    return;
+  }
+
+#ifndef _WIN32
+  MixedBinaryBroadcastCompute<xpu, OP, LOP, ROP>(attrs, ctx, inputs, req, outputs);
+#else
+  MixedBinaryBroadcastCompute<xpu, OP>(attrs, ctx, inputs, req, outputs);
+#endif
+}
+
 template<typename xpu, typename LOP, typename ROP>
 void MixedBinaryBackwardUseIn(const nnvm::NodeAttrs& attrs,
                               const OpContext& ctx,
diff --git a/src/operator/operator_tune-inl.h b/src/operator/operator_tune-inl.h
index 1dbcf4298918..da02c1c2b93b 100644
--- a/src/operator/operator_tune-inl.h
+++ b/src/operator/operator_tune-inl.h
@@ -116,6 +116,10 @@ class OperatorTune : public OperatorTuneByType<DType> {
     TuneAll();
   }
 
+  ~OperatorTune() {
+    delete[] data_set_;
+  }
+
   /*!
    * \brief Initialize the OperatorTune object
    * \return Whether the OperatorTune object was successfully initialized
@@ -124,7 +128,7 @@ class OperatorTune : public OperatorTuneByType<DType> {
     if (!initialized_) {
       initialized_ = true;
       // Generate some random data for calling the operator kernels
-      data_set_.reserve(0x100);
+      data_set_ = reinterpret_cast<DType*>(new char[0x100 * sizeof(DType)]);
       std::random_device rd;
       std::mt19937 gen(rd());
       if (!std::is_integral<DType>::value) {
@@ -136,7 +140,7 @@
           --n;
           continue;
         }
-        data_set_.emplace_back(val);
+        data_set_[n] = val;
       }
     } else {
       std::uniform_int_distribution<> dis(-128, 127);
@@ -147,7 +151,7 @@
           --n;
           continue;
         }
-        data_set_.emplace_back(val);
+        data_set_[n] = val;
       }
     }
     // Use this environment variable to generate new tuning statistics
@@ -517,7 +521,7 @@ class OperatorTune : public OperatorTuneByType<DType> {
   /*! \brief Number of passes to obtain an average */
   static constexpr duration_t OUTSIDE_COUNT = (1 << OUTSIDE_COUNT_SHIFT);
   /*! \brief Random data for timing operator calls */
-  static std::vector<DType> data_set_;
+  static DType* data_set_;
   /*! \brief Operators tuned */
   static std::unordered_set<std::string> operator_names_;
   /*! \brief Arbitary object to modify in OMP loop */
diff --git a/src/operator/operator_tune.cc b/src/operator/operator_tune.cc
index d0642eedfdf3..2d59c1d05ae6 100644
--- a/src/operator/operator_tune.cc
+++ b/src/operator/operator_tune.cc
@@ -39,7 +39,7 @@ double OperatorTuneBase::tuning_weight_scale_ = 0.0;
  */
 #define IMPLEMENT_OPERATOR_TUNE_STATICS_FOR_TYPE(__typ$) \
   template<> bool OperatorTune<__typ$>::initialized_ = false; \
-  template<> std::vector<__typ$> OperatorTune<__typ$>::data_set_ = {}; \
+  template<> __typ$* OperatorTune<__typ$>::data_set_ = nullptr; \
   template<> volatile tune::TuningMode OperatorTuneByType<__typ$>::tuning_mode_ = tune::kAuto; \
   template<> volatile int OperatorTune<__typ$>::volatile_int_ = 9;  /* arbitrary number */ \
   template<> std::unordered_set<std::string> OperatorTune<__typ$>::operator_names_({}); \
@@ -314,10 +314,10 @@ IMPLEMENT_UNARY_WORKLOAD_FWD_WITH_BOOL(mxnet::op::mshadow_op::np_logical_not);
 IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::nt);  // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::clip);  // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::clip);  // NOLINT()
-IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::plus);  // NOLINT()
-IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::minus);  // NOLINT()
-IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::mul);  // NOLINT()
-IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::div);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_FWD_WITH_BOOL(mxnet::op::mshadow_op::plus);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_FWD_WITH_BOOL(mxnet::op::mshadow_op::minus);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_FWD_WITH_BOOL(mxnet::op::mshadow_op::mul);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_FWD_WITH_BOOL(mxnet::op::mshadow_op::div);  // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::true_divide);  // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::minus_sign);  // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::rminus);  // NOLINT()
diff --git a/src/operator/tensor/elemwise_binary_broadcast_op.h b/src/operator/tensor/elemwise_binary_broadcast_op.h
index ad06df8d92be..b48ed389ba98 100644
--- a/src/operator/tensor/elemwise_binary_broadcast_op.h
+++ b/src/operator/tensor/elemwise_binary_broadcast_op.h
@@ -347,6 +347,9 @@ void BinaryBroadcastCompute(const nnvm::NodeAttrs& attrs,
   } else {
     if (req[0] != kNullOp) {
       mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
+      if (outputs[0].type_flag_ == mshadow::kBool) {
+        LOG(FATAL) << "Operator " << attrs.op->name << " does not support boolean type";
+      }
       MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
         BROADCAST_NDIM_SWITCH(ndim, NDim, {
           mshadow::Shape<NDim> oshape = new_oshape.get<NDim>();
@@ -361,6 +364,35 @@ void BinaryBroadcastCompute(const nnvm::NodeAttrs& attrs,
   }
 }
 
+template<typename xpu, typename OP>
+void BinaryBroadcastComputeWithBool(const nnvm::NodeAttrs& attrs,
+                                    const OpContext& ctx,
+                                    const std::vector<TBlob>& inputs,
+                                    const std::vector<OpReqType>& req,
+                                    const std::vector<TBlob>& outputs) {
+  if (outputs[0].shape_.Size() == 0U) return;
+  mxnet::TShape new_lshape, new_rshape, new_oshape;
+  int ndim = BinaryBroadcastShapeCompact(inputs[0].shape_, inputs[1].shape_, outputs[0].shape_,
+                                         &new_lshape, &new_rshape, &new_oshape);
+  if (!ndim) {
+    ElemwiseBinaryOp::ComputeWithBool<xpu, OP>(attrs, ctx, inputs, req, outputs);
+  } else {
+    if (req[0] != kNullOp) {
+      mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
+      MSHADOW_TYPE_SWITCH_WITH_BOOL(outputs[0].type_flag_, DType, {
+        BROADCAST_NDIM_SWITCH(ndim, NDim, {
+          mshadow::Shape<NDim> oshape = new_oshape.get<NDim>();
+          mshadow::Shape<NDim> lstride = mxnet_op::calc_stride(new_lshape.get<NDim>());
+          mshadow::Shape<NDim> rstride = mxnet_op::calc_stride(new_rshape.get<NDim>());
+          mxnet_op::Kernel<mxnet_op::binary_broadcast_kernel<NDim, OP>, xpu>::
+          template LaunchEx(s, new_oshape.Size(), req[0], lstride, rstride, oshape,
+                            inputs[0].dptr<DType>(), inputs[1].dptr<DType>(), outputs[0].dptr<DType>());
+        });
+      });
+    }
+  }
+}
+
 template<typename xpu, typename OP>
 void BinaryBroadcastComputeLogic(const nnvm::NodeAttrs& attrs,
                                  const OpContext& ctx,
diff --git a/src/operator/tensor/elemwise_binary_op.h b/src/operator/tensor/elemwise_binary_op.h
index da088c1dcc39..c046a28f16b2 100644
--- a/src/operator/tensor/elemwise_binary_op.h
+++ b/src/operator/tensor/elemwise_binary_op.h
@@ -485,6 +485,9 @@ class ElemwiseBinaryOp : public OpBase {
       Stream<xpu> *s = ctx.get_stream<xpu>();
       CHECK_EQ(inputs.size(), 2U);
       CHECK_EQ(outputs.size(), 1U);
+      if (outputs[0].type_flag_ == mshadow::kBool) {
+        LOG(FATAL) << "Operator " << attrs.op->name << " does not support boolean type";
+      }
       MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
         MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
           const size_t size = (minthree(outputs[0].Size(), inputs[0].Size(), inputs[1].Size())
@@ -499,6 +502,31 @@ class ElemwiseBinaryOp : public OpBase {
     }
   }
 
+  template<typename xpu, typename OP>
+  static void ComputeWithBool(const nnvm::NodeAttrs &attrs,
+                              const OpContext &ctx,
+                              const std::vector<TBlob> &inputs,
+                              const std::vector<OpReqType> &req,
+                              const std::vector<TBlob> &outputs) {
+    using namespace mxnet_op;
+    if (req[0] != kNullOp) {
+      Stream<xpu> *s = ctx.get_stream<xpu>();
+      CHECK_EQ(inputs.size(), 2U);
+      CHECK_EQ(outputs.size(), 1U);
+      MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
+        MSHADOW_TYPE_SWITCH_WITH_BOOL(outputs[0].type_flag_, DType, {
+          const size_t size = (minthree(outputs[0].Size(), inputs[0].Size(), inputs[1].Size())
+          + DataType<DType>::kLanes - 1) / DataType<DType>::kLanes;
+          if (size != 0) {
+            Kernel<mxnet_op::op_with_req<OP, Req>, xpu>::Launch(s, size,
+            outputs[0].dptr<DType>(),
+            inputs[0].dptr<DType>(), inputs[1].dptr<DType>());
+          }
+        });
+      });
+    }
+  }
+
   template<typename xpu, typename OP>
   static void ComputeLogic(const nnvm::NodeAttrs &attrs,
                            const OpContext &ctx,
diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py
index 32cd5b10717e..33927199d27b 100644
--- a/tests/python/unittest/test_numpy_op.py
+++ b/tests/python/unittest/test_numpy_op.py
@@ -1691,6 +1691,7 @@ def hybrid_forward(self, F, a, b, *args, **kwargs):
         'subtract': (-1.0, 1.0),
         'multiply': (-1.0, 1.0),
     }
+
     shape_pairs = [((3, 2), (3, 2)),
                    ((3, 2), (3, 1)),
                    ((3, 1), (3, 0)),
@@ -1698,6 +1699,7 @@ def hybrid_forward(self, F, a, b, *args, **kwargs):
                    ((2, 3, 4), (3, 1)),
                    ((2, 3), ()),
                    ((), (2, 3))]
+
     itypes = [np.bool, np.int8, np.int32, np.int64]
     ftypes = [np.float16, np.float32, np.float64]
     for func, func_data in funcs.items():
@@ -1713,6 +1715,60 @@ def hybrid_forward(self, F, a, b, *args, **kwargs):
                 check_mixed_precision_binary_func(func, low, high, lshape, rshape, type1, type2)
 
 
+@with_seed()
+@use_np
+def test_np_boolean_binary_funcs():
+    def check_boolean_binary_func(func, mx_x1, mx_x2):
+        class TestBooleanBinary(HybridBlock):
+            def __init__(self, func):
+                super(TestBooleanBinary, self).__init__()
+                self._func = func
+
+            def hybrid_forward(self, F, a, b, *args, **kwargs):
+                return getattr(F.np, self._func)(a, b)
+
+        np_x1 = mx_x1.asnumpy()
+        np_x2 = mx_x2.asnumpy()
+        np_func = getattr(_np, func)
+        mx_func = TestBooleanBinary(func)
+        for hybridize in [True, False]:
+            if hybridize:
+                mx_func.hybridize()
+            np_out = np_func(np_x1, np_x2)
+            with mx.autograd.record():
+                y = mx_func(mx_x1, mx_x2)
+            assert y.shape == np_out.shape
+            assert_almost_equal(y.asnumpy(), np_out.astype(y.dtype), rtol=1e-3, atol=1e-20,
+                                use_broadcast=False, equal_nan=True)
+
+        np_out = getattr(_np, func)(np_x1, np_x2)
+        mx_out = getattr(mx.np, func)(mx_x1, mx_x2)
+        assert mx_out.shape == np_out.shape
+        assert_almost_equal(mx_out.asnumpy(), np_out.astype(mx_out.dtype), rtol=1e-3, atol=1e-20,
+                            use_broadcast=False, equal_nan=True)
+
+
+    funcs = [
+        'add',
+        'multiply',
+        'true_divide',
+    ]
+
+    shape_pairs = [((3, 2), (3, 2)),
+                   ((3, 2), (3, 1)),
+                   ((3, 1), (3, 0)),
+                   ((0, 2), (1, 2)),
+                   ((2, 3, 4), (3, 1)),
+                   ((2, 3), ()),
+                   ((), (2, 3))]
+
+    for lshape, rshape in shape_pairs:
+        for func in funcs:
+            x1 = np.random.uniform(size=lshape) > 0.5
+            x2 = np.random.uniform(size=rshape) > 0.5
+            check_boolean_binary_func(func, x1, x2)
+
+
 @with_seed()
 @use_np
 def test_npx_relu():
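Usage note (not part of the patch): the sketch below illustrates what the new `*WithBool` kernels enable from the Python API, mirroring the cases exercised by `test_np_boolean_binary_funcs` above. It assumes an MXNet build that includes this change; the shapes and values are illustrative only.

```python
import mxnet as mx
from mxnet import npx
npx.set_np()  # enable NumPy semantics, as the @use_np tests do

# Boolean ndarrays created the same way as in the test above
a = mx.np.random.uniform(size=(2, 3)) > 0.5
b = mx.np.random.uniform(size=(3,)) > 0.5

# bool (x) bool now dispatches to the WithBool kernels instead of failing the dtype switch
print((a * b).dtype)                  # expected: bool
print((a + b).dtype)                  # expected: bool
print(mx.np.true_divide(a, b).dtype)  # expected: a float dtype, as in NumPy
```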