Skip to content

Commit

Permalink
[Tensor] Update tensorbase for efficient creation of new tensor class.
Browse files Browse the repository at this point in the history
This PR updates the TensorBase class to make mathematical operations that are not required to create a new tensor class.
This change allows developers to easily create new classes without implementing math operations.
Note that these functions should be implemented to utilize tensor operations fully.

**Changes proposed in this PR:**
- Change math operation function from pure virtual function to virtual function
- Add a private function to get the data type as a string

**Self-evaluation:**
1. Build test: [X]Passed [ ]Failed [ ]Skipped
2. Run test:   [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: Donghyeon Jeong <dhyeon.jeong@samsung.com>
  • Loading branch information
djeong20 authored and jijoongmoon committed Aug 5, 2024
1 parent 8877d6a commit 5873a4e
Show file tree
Hide file tree
Showing 2 changed files with 293 additions and 48 deletions.
232 changes: 232 additions & 0 deletions nntrainer/tensor/tensor_base.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -342,4 +342,236 @@ void TensorBase::calculateFlattenDot(
ldc = (getFormat() == Tformat::NHWC) ? output.channel() : output.width();
}

/**
* Please note that the following functions need to be implemented in a child
* class to utilize tensor operations fully — operations such as addition,
* division, multiplication, dot production, data averaging, and so on.
*/
void TensorBase::setRandNormal(float mean, float stddev) {
throw std::invalid_argument(
"Tensor::setRandNormal() is currently not supported in tensor data type " +
getStringDataType());
}

void TensorBase::setRandUniform(float min, float max) {
throw std::invalid_argument(
"Tensor::setRandUniform() is currently not supported in tensor data type " +
getStringDataType());
}

void TensorBase::setRandBernoulli(float probability) {
throw std::invalid_argument("Tensor::setRandBernoulli() is currently not "
"supported in tensor data type " +
getStringDataType());
}

Tensor TensorBase::multiply_strided(Tensor const &m, Tensor &output,
const float beta) const {
throw std::invalid_argument("Tensor::multiply_strided() is currently not "
"supported in tensor data type " +
getStringDataType());
return output;
}

int TensorBase::multiply_i(float const &value) {
throw std::invalid_argument(
"Tensor::multiply_i() is currently not supported in tensor data type " +
getStringDataType());
return ML_ERROR_NOT_SUPPORTED;
}

Tensor &TensorBase::multiply(float const &value, Tensor &output) const {
throw std::invalid_argument(
"Tensor::multiply() is currently not supported in tensor data type " +
getStringDataType());
return output;
}

Tensor &TensorBase::multiply(Tensor const &m, Tensor &output,
const float beta) const {
throw std::invalid_argument(
"Tensor::multiply() is currently not supported in tensor data type " +
getStringDataType());
return output;
}

Tensor &TensorBase::divide(float const &value, Tensor &output) const {
throw std::invalid_argument(
"Tensor::divide() is currently not supported in tensor data type " +
getStringDataType());
return output;
}

Tensor &TensorBase::divide(Tensor const &m, Tensor &output) const {
throw std::invalid_argument(
"Tensor::divide() is currently not supported in tensor data type " +
getStringDataType());
return output;
}

Tensor &TensorBase::add_strided(Tensor const &input, Tensor &output,
const float beta) const {
throw std::invalid_argument(
"Tensor::add_strided() is currently not supported in tensor data type " +
getStringDataType());
return output;
}

int TensorBase::add_i(Tensor const &m, Tensor &output, float const alpha) {
throw std::invalid_argument(
"Tensor::add_i() is currently not supported in tensor data type " +
getStringDataType());
return ML_ERROR_NOT_SUPPORTED;
}

int TensorBase::add_i_partial(unsigned int len, unsigned int addr_idx,
Tensor &m, unsigned int incX, unsigned int incY,
const Tensor alphas, unsigned int alpha_idx) {
throw std::invalid_argument(
"Tensor::add_i_partial() is currently not supported in tensor data type " +
getStringDataType());
return ML_ERROR_NOT_SUPPORTED;
}

Tensor &TensorBase::add(float const &value, Tensor &output) const {
throw std::invalid_argument(
"Tensor::add() is currently not supported in tensor data type " +
getStringDataType());
return output;
}

Tensor &TensorBase::add(Tensor const &m, Tensor &output,
float const alpha) const {
throw std::invalid_argument(
"Tensor::add() is currently not supported in tensor data type " +
getStringDataType());
return output;
}

Tensor &TensorBase::subtract(float const &value, Tensor &output) const {
throw std::invalid_argument(
"Tensor::subtract() is currently not supported in tensor data type " +
getStringDataType());
return output;
}

void TensorBase::sum_by_batch(Tensor &output) const {
throw std::invalid_argument(
"Tensor::sum_by_batch() is currently not supported in tensor data type " +
getStringDataType());
}

Tensor &TensorBase::sum(unsigned int axis, Tensor &output, float alpha,
float beta) const {
throw std::invalid_argument(
"Tensor::sum() is currently not supported in tensor data type " +
getStringDataType());
return output;
}

float TensorBase::l2norm() const {
throw std::invalid_argument(
"Tensor::l2norm() is currently not supported in tensor data type " +
getStringDataType());
return ML_ERROR_NOT_SUPPORTED;
}

Tensor &TensorBase::pow(float exponent, Tensor &output) const {
throw std::invalid_argument(
"Tensor::pow() is currently not supported in tensor data type " +
getStringDataType());
return output;
}

Tensor &TensorBase::erf(Tensor &output) const {
throw std::invalid_argument(
"Tensor::erf() is currently not supported in tensor data type " +
getStringDataType());
return output;
}

void TensorBase::sin(Tensor &out, float alpha) {
throw std::invalid_argument(
"Tensor::sin() is currently not supported in tensor data type " +
getStringDataType());
}

void TensorBase::cos(Tensor &out, float alpha) {
throw std::invalid_argument(
"Tensor::cos() is currently not supported in tensor data type " +
getStringDataType());
}

void TensorBase::inv_sqrt(Tensor &out) {
throw std::invalid_argument(
"Tensor::inv_sqrt() is currently not supported in tensor data type " +
getStringDataType());
}

Tensor &TensorBase::dot(Tensor const &input, Tensor &output, bool trans,
bool trans_in, float beta) const {
throw std::invalid_argument(
"Tensor::dot() is currently not supported in tensor data type " +
getStringDataType());
return output;
}

void TensorBase::dropout_mask(float dropout) {
throw std::invalid_argument(
"Tensor::dropout_mask() is currently not supported in tensor data type " +
getStringDataType());
}

void TensorBase::filter_mask(const Tensor &mask_len, bool reverse) {
throw std::invalid_argument(
"Tensor::filter_mask() is currently not supported in tensor data type " +
getStringDataType());
}

void TensorBase::zoneout_mask(Tensor &opposite, float zoneout) {
throw std::invalid_argument(
"Tensor::zoneout_mask() is currently not supported in tensor data type " +
getStringDataType());
}

std::vector<Tensor> TensorBase::split(std::vector<size_t> sizes, int axis) {
throw std::invalid_argument(
"Tensor::split() is currently not supported in tensor data type " +
getStringDataType());
std::vector<Tensor> ret;
return ret;
}

Tensor TensorBase::concat(const std::vector<Tensor> &tensors, int axis) {
throw std::invalid_argument(
"Tensor::concat() is currently not supported in tensor data type " +
getStringDataType());
return tensors[0];
}

Tensor &TensorBase::apply(std::function<float(float)> f, Tensor &output) const {
throw std::invalid_argument(
"Tensor::apply(std::function<float(float)> f, Tensor &output) is "
"not supported in tensor data type " +
getStringDataType());
return output;
}

#ifdef ENABLE_FP16
Tensor &TensorBase::apply(std::function<_FP16(_FP16)> f, Tensor &output) const {
throw std::invalid_argument(
"Tensor::apply(std::function<_FP16(_FP16)> f, Tensor &output) is "
"not supported in tensor data type " +
getStringDataType());
return output;
}
#endif

Tensor &TensorBase::transpose(const std::string &direction, Tensor &out) const {
throw std::invalid_argument(
"Tensor::transpose() is currently not supported in tensor data type " +
getStringDataType());
return out;
}

} // namespace nntrainer
Loading

0 comments on commit 5873a4e

Please sign in to comment.