Skip to content

Commit

Permalink
[refactor] Restructure getStringDataType function
Browse files Browse the repository at this point in the history
This patch updates the getStringDataType function structure to utilize method overriding.

**Self-evaluation:**
1. Build test: [X]Passed [ ]Failed [ ]Skipped
2. Run test:   [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: Donghyeon Jeong <dhyeon.jeong@samsung.com>
  • Loading branch information
djeong20 authored and jijoongmoon committed Aug 5, 2024
1 parent 5873a4e commit 32d901c
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 21 deletions.
6 changes: 6 additions & 0 deletions nntrainer/tensor/float_tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -500,6 +500,12 @@ class FloatTensor : public TensorBase {
const float *, float *)>
v_func,
Tensor &output) const;

/**
* @brief Get the Data Type String object
* @return std::string of tensor data type (FP32)
*/
std::string getStringDataType() const override { return "FP32"; }
};

} // namespace nntrainer
Expand Down
6 changes: 6 additions & 0 deletions nntrainer/tensor/half_tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -491,6 +491,12 @@ class HalfTensor : public TensorBase {
const _FP16 *, _FP16 *)>
v_func,
Tensor &output) const;

/**
* @brief Get the Data Type String object
* @return std::string of tensor data type (FP16)
*/
std::string getStringDataType() const override { return "FP16"; }
};

} // namespace nntrainer
Expand Down
24 changes: 3 additions & 21 deletions nntrainer/tensor/tensor_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -735,28 +735,10 @@ class TensorBase {
/**
* @brief Get the Data Type String object
* @return std::string of tensor data type
* @note TensorBase::getStringDataType() should not be called. Please define
* this function in the derived class to the corresponding data type.
*/
std::string getStringDataType() const {
std::string res;
switch (getDataType()) {
case Tdatatype::FP32:
res = "FP32";
break;
case Tdatatype::FP16:
res = "FP16";
break;
case Tdatatype::QINT8:
res = "QINT8";
break;
case Tdatatype::QINT4:
res = "QINT4";
break;
default:
res = "Undefined type";
break;
}
return res;
}
virtual std::string getStringDataType() const { return "Undefined type"; }
};

/**
Expand Down

0 comments on commit 32d901c

Please sign in to comment.