From 5873a4e68a40d28225d1c5b293836bb62384552c Mon Sep 17 00:00:00 2001
From: Donghyeon Jeong
Date: Tue, 30 Jul 2024 11:09:39 +0900
Subject: [PATCH] [Tensor] Update tensorbase for efficient creation of new tensor class.

This PR updates the TensorBase class so that mathematical operations are no
longer required in order to create a new tensor class. This change allows
developers to easily create new classes without implementing math operations
(a standalone sketch of the pattern is appended after the patch). Note that
these functions should still be implemented in a child class to utilize
tensor operations fully.

**Changes proposed in this PR:**
- Change math operation functions from pure virtual functions to virtual functions
- Add a private function to get the data type as a string

**Self-evaluation:**
1. Build test: [X]Passed [ ]Failed [ ]Skipped
2. Run test: [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: Donghyeon Jeong
---
 nntrainer/tensor/tensor_base.cpp | 232 +++++++++++++++++++++++++++++++
 nntrainer/tensor/tensor_base.h   | 109 ++++++++-------
 2 files changed, 293 insertions(+), 48 deletions(-)

diff --git a/nntrainer/tensor/tensor_base.cpp b/nntrainer/tensor/tensor_base.cpp
index ed34654d04..d982a4147e 100644
--- a/nntrainer/tensor/tensor_base.cpp
+++ b/nntrainer/tensor/tensor_base.cpp
@@ -342,4 +342,236 @@ void TensorBase::calculateFlattenDot(
   ldc = (getFormat() == Tformat::NHWC) ? output.channel() : output.width();
 }
 
+/**
+ * Please note that the following functions need to be implemented in a child
+ * class to utilize tensor operations fully, such as addition, division,
+ * multiplication, dot product, data averaging, and so on.
+ */
+void TensorBase::setRandNormal(float mean, float stddev) {
+  throw std::invalid_argument(
+    "Tensor::setRandNormal() is currently not supported in tensor data type " +
+    getStringDataType());
+}
+
+void TensorBase::setRandUniform(float min, float max) {
+  throw std::invalid_argument(
+    "Tensor::setRandUniform() is currently not supported in tensor data type " +
+    getStringDataType());
+}
+
+void TensorBase::setRandBernoulli(float probability) {
+  throw std::invalid_argument("Tensor::setRandBernoulli() is currently not "
+                              "supported in tensor data type " +
+                              getStringDataType());
+}
+
+Tensor TensorBase::multiply_strided(Tensor const &m, Tensor &output,
+                                    const float beta) const {
+  throw std::invalid_argument("Tensor::multiply_strided() is currently not "
+                              "supported in tensor data type " +
+                              getStringDataType());
+  return output;
+}
+
+int TensorBase::multiply_i(float const &value) {
+  throw std::invalid_argument(
+    "Tensor::multiply_i() is currently not supported in tensor data type " +
+    getStringDataType());
+  return ML_ERROR_NOT_SUPPORTED;
+}
+
+Tensor &TensorBase::multiply(float const &value, Tensor &output) const {
+  throw std::invalid_argument(
+    "Tensor::multiply() is currently not supported in tensor data type " +
+    getStringDataType());
+  return output;
+}
+
+Tensor &TensorBase::multiply(Tensor const &m, Tensor &output,
+                             const float beta) const {
+  throw std::invalid_argument(
+    "Tensor::multiply() is currently not supported in tensor data type " +
+    getStringDataType());
+  return output;
+}
+
+Tensor &TensorBase::divide(float const &value, Tensor &output) const {
+  throw std::invalid_argument(
+    "Tensor::divide() is currently not supported in tensor data type " +
+    getStringDataType());
+  return output;
+}
+
+Tensor &TensorBase::divide(Tensor const &m, Tensor &output) const {
+  throw std::invalid_argument(
+    "Tensor::divide() is currently not supported in tensor data type " +
+    getStringDataType());
+  return output;
+}
+
+Tensor &TensorBase::add_strided(Tensor const &input, Tensor &output,
+                                const float beta) const {
+  throw std::invalid_argument(
+    "Tensor::add_strided() is currently not supported in tensor data type " +
+    getStringDataType());
+  return output;
+}
+
+int TensorBase::add_i(Tensor const &m, Tensor &output, float const alpha) {
+  throw std::invalid_argument(
+    "Tensor::add_i() is currently not supported in tensor data type " +
+    getStringDataType());
+  return ML_ERROR_NOT_SUPPORTED;
+}
+
+int TensorBase::add_i_partial(unsigned int len, unsigned int addr_idx,
+                              Tensor &m, unsigned int incX, unsigned int incY,
+                              const Tensor alphas, unsigned int alpha_idx) {
+  throw std::invalid_argument(
+    "Tensor::add_i_partial() is currently not supported in tensor data type " +
+    getStringDataType());
+  return ML_ERROR_NOT_SUPPORTED;
+}
+
+Tensor &TensorBase::add(float const &value, Tensor &output) const {
+  throw std::invalid_argument(
+    "Tensor::add() is currently not supported in tensor data type " +
+    getStringDataType());
+  return output;
+}
+
+Tensor &TensorBase::add(Tensor const &m, Tensor &output,
+                        float const alpha) const {
+  throw std::invalid_argument(
+    "Tensor::add() is currently not supported in tensor data type " +
+    getStringDataType());
+  return output;
+}
+
+Tensor &TensorBase::subtract(float const &value, Tensor &output) const {
+  throw std::invalid_argument(
+    "Tensor::subtract() is currently not supported in tensor data type " +
+    getStringDataType());
+  return output;
+}
+
+void TensorBase::sum_by_batch(Tensor &output) const {
+  throw std::invalid_argument(
+    "Tensor::sum_by_batch() is currently not supported in tensor data type " +
+    getStringDataType());
+}
+
+Tensor &TensorBase::sum(unsigned int axis, Tensor &output, float alpha,
+                        float beta) const {
+  throw std::invalid_argument(
+    "Tensor::sum() is currently not supported in tensor data type " +
+    getStringDataType());
+  return output;
+}
+
+float TensorBase::l2norm() const {
+  throw std::invalid_argument(
+    "Tensor::l2norm() is currently not supported in tensor data type " +
+    getStringDataType());
+  return ML_ERROR_NOT_SUPPORTED;
+}
+
+Tensor &TensorBase::pow(float exponent, Tensor &output) const {
+  throw std::invalid_argument(
+    "Tensor::pow() is currently not supported in tensor data type " +
+    getStringDataType());
+  return output;
+}
+
+Tensor &TensorBase::erf(Tensor &output) const {
+  throw std::invalid_argument(
+    "Tensor::erf() is currently not supported in tensor data type " +
+    getStringDataType());
+  return output;
+}
+
+void TensorBase::sin(Tensor &out, float alpha) {
+  throw std::invalid_argument(
+    "Tensor::sin() is currently not supported in tensor data type " +
+    getStringDataType());
+}
+
+void TensorBase::cos(Tensor &out, float alpha) {
+  throw std::invalid_argument(
+    "Tensor::cos() is currently not supported in tensor data type " +
+    getStringDataType());
+}
+
+void TensorBase::inv_sqrt(Tensor &out) {
+  throw std::invalid_argument(
+    "Tensor::inv_sqrt() is currently not supported in tensor data type " +
+    getStringDataType());
+}
+
+Tensor &TensorBase::dot(Tensor const &input, Tensor &output, bool trans,
+                        bool trans_in, float beta) const {
+  throw std::invalid_argument(
+    "Tensor::dot() is currently not supported in tensor data type " +
+    getStringDataType());
+  return output;
+}
+
+void TensorBase::dropout_mask(float dropout) {
+  throw std::invalid_argument(
+    "Tensor::dropout_mask() is currently not supported in tensor data type " +
+    getStringDataType());
+}
+
+void TensorBase::filter_mask(const Tensor &mask_len, bool reverse) {
+  throw std::invalid_argument(
+    "Tensor::filter_mask() is currently not supported in tensor data type " +
+    getStringDataType());
+}
+
+void TensorBase::zoneout_mask(Tensor &opposite, float zoneout) {
+  throw std::invalid_argument(
+    "Tensor::zoneout_mask() is currently not supported in tensor data type " +
+    getStringDataType());
+}
+
+std::vector<Tensor> TensorBase::split(std::vector<size_t> sizes, int axis) {
+  throw std::invalid_argument(
+    "Tensor::split() is currently not supported in tensor data type " +
+    getStringDataType());
+  std::vector<Tensor> ret;
+  return ret;
+}
+
+Tensor TensorBase::concat(const std::vector<Tensor> &tensors, int axis) {
+  throw std::invalid_argument(
+    "Tensor::concat() is currently not supported in tensor data type " +
+    getStringDataType());
+  return tensors[0];
+}
+
+Tensor &TensorBase::apply(std::function<float(float)> f, Tensor &output) const {
+  throw std::invalid_argument(
+    "Tensor::apply(std::function<float(float)> f, Tensor &output) is "
+    "not supported in tensor data type " +
+    getStringDataType());
+  return output;
+}
+
+#ifdef ENABLE_FP16
+Tensor &TensorBase::apply(std::function<_FP16(_FP16)> f, Tensor &output) const {
+  throw std::invalid_argument(
+    "Tensor::apply(std::function<_FP16(_FP16)> f, Tensor &output) is "
+    "not supported in tensor data type " +
+    getStringDataType());
+  return output;
+}
+#endif
+
+Tensor &TensorBase::transpose(const std::string &direction, Tensor &out) const {
+  throw std::invalid_argument(
+    "Tensor::transpose() is currently not supported in tensor data type " +
+    getStringDataType());
+  return out;
+}
+
 } // namespace nntrainer
diff --git a/nntrainer/tensor/tensor_base.h b/nntrainer/tensor/tensor_base.h
index 2cded86154..576ed5db1f 100644
--- a/nntrainer/tensor/tensor_base.h
+++ b/nntrainer/tensor/tensor_base.h
@@ -225,17 +225,17 @@ class TensorBase {
   /**
    * @copydoc Tensor::setRandNormal()
    */
-  virtual void setRandNormal(float mean, float stddev) = 0;
+  virtual void setRandNormal(float mean, float stddev);
 
   /**
    * @copydoc Tensor::setRandBernoulli()
    */
-  virtual void setRandUniform(float min, float max) = 0;
+  virtual void setRandUniform(float min, float max);
 
   /**
    * @copydoc Tensor::setRandBernoulli()
    */
-  virtual void setRandBernoulli(float probability) = 0;
+  virtual void setRandBernoulli(float probability);
 
   /**
    * @copydoc Tensor::initialize()
@@ -252,125 +252,115 @@ class TensorBase {
    * const float beta)
    */
   virtual Tensor multiply_strided(Tensor const &m, Tensor &output,
-                                  const float beta) const = 0;
+                                  const float beta) const;
 
   /**
    * @copydoc Tensor::multiply_i(float const &value)
    */
-  virtual int multiply_i(float const &value) = 0;
+  virtual int multiply_i(float const &value);
 
   /**
-   * @copydoc Tensor::multiply(float const &value, Tensor &out)
+   * @copydoc Tensor::multiply(float const &value, Tensor &output)
    */
-  virtual Tensor &multiply(float const &value, Tensor &out) const = 0;
+  virtual Tensor &multiply(float const &value, Tensor &output) const;
 
   /**
    * @copydoc Tensor::multiply(Tensor const &m, Tensor &output, const
    * float beta = 0.0)
    */
   virtual Tensor &multiply(Tensor const &m, Tensor &output,
-                           const float beta = 0.0) const = 0;
+                           const float beta = 0.0) const;
 
   /**
    * @copydoc Tensor::divide(float const &value, Tensor &output)
    */
-  virtual Tensor &divide(float const &value, Tensor &output) const = 0;
+  virtual Tensor &divide(float const &value, Tensor &output) const;
 
   /**
    * @copydoc Tensor::divide(Tensor const &m, Tensor &output)
    */
-  virtual Tensor &divide(Tensor const &m, Tensor &output) const = 0;
+  virtual Tensor &divide(Tensor const &m, Tensor &output) const;
 
   /**
    * @copydoc Tensor::add_strided(Tensor const &input, Tensor &output,
    * const float beta)
    */
   virtual Tensor &add_strided(Tensor const &input, Tensor &output,
-                              const float beta) const = 0;
+                              const float beta) const;
 
   /**
    * @copydoc Tensor::add_i(Tensor const &m, float const alpha)
    */
-  virtual int add_i(Tensor const &m, Tensor &output, float const alpha) = 0;
+  virtual int add_i(Tensor const &m, Tensor &output, float const alpha);
 
   /**
    * @copydoc Tensor::add_i_partial()
    */
   virtual int add_i_partial(unsigned int len, unsigned int addr_idx, Tensor &m,
                             unsigned int incX, unsigned int incY,
-                            const Tensor alphas, unsigned int alpha_idx) = 0;
+                            const Tensor alphas, unsigned int alpha_idx);
 
   /**
    * @copydoc Tensor::add(float const &value, Tensor &output)
    */
-  virtual Tensor &add(float const &value, Tensor &output) const = 0;
+  virtual Tensor &add(float const &value, Tensor &output) const;
 
   /**
    * @copydoc Tensor::add(Tensor const &m, Tensor &output, float const
    * alpha)
    */
-  virtual Tensor &add(Tensor const &m, Tensor &output,
-                      float const alpha) const = 0;
+  virtual Tensor &add(Tensor const &m, Tensor &output, float const alpha) const;
 
   /**
    * @copydoc Tensor::subtract(float const &value, Tensor &output)
    */
-  virtual Tensor &subtract(float const &value, Tensor &output) const = 0;
+  virtual Tensor &subtract(float const &value, Tensor &output) const;
 
   /**
    * @brief Sum all the Tensor elements according to the batch
    * @param[out] output Tensor(batch, 1, 1, 1)
    */
-  virtual void sum_by_batch(Tensor &output) const = 0;
+  virtual void sum_by_batch(Tensor &output) const;
 
   /**
    * @copydoc Tensor::sum(unsigned int axis, Tensor &output, float alpha,
    * float beta) const
    */
   virtual Tensor &sum(unsigned int axis, Tensor &output, float alpha,
-                      float beta) const = 0;
+                      float beta) const;
 
   /**
    * @copydoc Tensor::l2norm
    */
-  virtual float l2norm() const = 0;
+  virtual float l2norm() const;
 
   /**
    * @copydoc Tensor::pow(float exponent, Tensor &output)
    */
-  virtual Tensor &pow(float exponent, Tensor &output) const = 0;
+  virtual Tensor &pow(float exponent, Tensor &output) const;
 
   /**
    * @copydoc Tensor::erf(Tensor &output)
    */
-  virtual Tensor &erf(Tensor &output) const = 0;
+  virtual Tensor &erf(Tensor &output) const;
 
   /**
    * @brief sin transform function
    * @param[out] out out to store the result
    */
-  virtual void sin(Tensor &out, float alpha = 1.0) {
-    throw std::invalid_argument(
-      "Tensor::sin not supported in current tensor data type.");
-  }
+  virtual void sin(Tensor &out, float alpha = 1.0);
 
   /**
    * @brief cos transform function
    * @param[out] out out to store the result
    */
-  virtual void cos(Tensor &out, float alpha = 1.0) {
-    throw std::invalid_argument(
-      "Tensor::cos not supported in current tensor data type.");
-  }
+  virtual void cos(Tensor &out, float alpha = 1.0);
 
   /**
    * @brief inverse squared root function
    * @param[out] out out to store the result
    */
-  virtual void inv_sqrt(Tensor &out) {
-    throw std::invalid_argument(
-      "Tensor::inv_sqrt not supported in current tensor data type.");
-  }
+  virtual void inv_sqrt(Tensor &out);
 
   /**
    * @brief Dot Product of Tensor ( equal MxM )
@@ -384,32 +374,32 @@ class TensorBase {
    * @retval Calculated Tensor
    */
   virtual Tensor &dot(Tensor const &input, Tensor &output, bool trans,
-                      bool trans_in, float beta) const = 0;
+                      bool trans_in, float beta) const;
 
   /**
    * @copydoc Tensor::dropout_mask(float dropout)
    */
-  virtual void dropout_mask(float dropout) = 0;
+  virtual void dropout_mask(float dropout);
 
   /**
    * @copydoc Tensor::filter_mask(const Tensor &mask_len, bool reverse)
    */
-  virtual void filter_mask(const Tensor &mask_len, bool reverse) = 0;
+  virtual void filter_mask(const Tensor &mask_len, bool reverse);
 
   /**
    * @copydoc Tensor::zoneout_mask(Tensor &opposite, float zoneout)
    */
-  virtual void zoneout_mask(Tensor &opposite, float zoneout) = 0;
+  virtual void zoneout_mask(Tensor &opposite, float zoneout);
 
   /**
    * @copydoc Tensor::split(std::vector<size_t> sizes, int axis)
    */
-  virtual std::vector<Tensor> split(std::vector<size_t> sizes, int axis) = 0;
+  virtual std::vector<Tensor> split(std::vector<size_t> sizes, int axis);
 
   /**
    * @copydoc Tensor::concat(const std::vector<Tensor> &tensors, int axis)
    */
-  virtual Tensor concat(const std::vector<Tensor> &tensors, int axis) = 0;
+  virtual Tensor concat(const std::vector<Tensor> &tensors, int axis);
 
   /**
    * @copydoc Tensor::print(std::ostream &out)
@@ -418,18 +408,16 @@ class TensorBase {
 
   /**
    * @copydoc Tensor::apply(std::function<float(float)> f, Tensor &output)
+   * @note This will be only used in FloatTensor.
    */
-  virtual Tensor &apply(std::function<float(float)> f, Tensor &output) const {
-    return output;
-  }
+  virtual Tensor &apply(std::function<float(float)> f, Tensor &output) const;
 
 #ifdef ENABLE_FP16
   /**
    * @copydoc Tensor::apply(std::function<_FP16(_FP16)> f, Tensor &output)
+   * @note This will be only used in HalfTensor.
    */
-  virtual Tensor &apply(std::function<_FP16(_FP16)> f, Tensor &output) const {
-    return output;
-  }
+  virtual Tensor &apply(std::function<_FP16(_FP16)> f, Tensor &output) const;
 #endif
 
   /**
@@ -476,8 +464,7 @@ class TensorBase {
   /**
    * @copydoc Tensor::transpose(const std::string &direction, Tensor &out)
    */
-  virtual Tensor &transpose(const std::string &direction,
-                            Tensor &out) const = 0;
+  virtual Tensor &transpose(const std::string &direction, Tensor &out) const;
 
   /**
    * @brief put data of Tensor
@@ -744,6 +731,32 @@ class TensorBase {
                             unsigned int &input_last_axis, unsigned int &M,
                             unsigned int &N, unsigned int &K, unsigned int &lda,
                             unsigned int &ldb, unsigned int &ldc) const;
+
+  /**
+   * @brief Get the Data Type String object
+   * @return std::string of tensor data type
+   */
+  std::string getStringDataType() const {
+    std::string res;
+    switch (getDataType()) {
+    case Tdatatype::FP32:
+      res = "FP32";
+      break;
+    case Tdatatype::FP16:
+      res = "FP16";
+      break;
+    case Tdatatype::QINT8:
+      res = "QINT8";
+      break;
+    case Tdatatype::QINT4:
+      res = "QINT4";
+      break;
+    default:
+      res = "Undefined type";
+      break;
+    }
+    return res;
+  }
 };
 
 /**
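
The following is a standalone sketch of the pattern this PR applies: math operations become plain virtual functions with throwing defaults, so a derived tensor class compiles without overriding them and only fails at runtime when an unimplemented operation is actually called. It is illustrative only and not nntrainer code; the class names (`TensorLike`, `QuantizedTensorLike`) are hypothetical, and the real TensorBase keeps additional members that a concrete subclass must still provide.

```cpp
// Illustrative sketch only (not nntrainer code): the "pure virtual -> virtual
// with a throwing default" refactoring applied by this PR, in miniature.
#include <iostream>
#include <stdexcept>
#include <string>

class TensorLike {
public:
  virtual ~TensorLike() = default;

  // Before this PR the equivalent member was "= 0"; now it has a default
  // implementation that reports the concrete data type and throws.
  virtual TensorLike &multiply(float value, TensorLike &output) const {
    throw std::invalid_argument(
      "multiply() is currently not supported in tensor data type " +
      getStringDataType());
    return output;
  }

protected:
  // Counterpart of the getStringDataType() helper added by the PR.
  virtual std::string getStringDataType() const { return "Undefined type"; }
};

// A new tensor class can now be introduced without implementing any math ops.
class QuantizedTensorLike : public TensorLike {
protected:
  std::string getStringDataType() const override { return "QINT8"; }
};

int main() {
  QuantizedTensorLike in, out;
  try {
    in.multiply(2.0f, out); // not overridden yet, so the default throws
  } catch (const std::invalid_argument &e) {
    std::cout << e.what() << '\n'; // "... not supported in tensor data type QINT8"
  }
  return 0;
}
```

The trade-off of this design is that missing implementations surface as std::invalid_argument at call time rather than as compile-time errors, which is why the commit message notes that these functions should still be implemented in a child class to utilize tensor operations fully.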