Skip to content

Commit

Permalink
Guard overflow and data loss with SafeInt
Browse files Browse the repository at this point in the history
  • Loading branch information
adityagoel4512 committed Jan 12, 2024
1 parent cd5e90e commit d39e5d0
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions onnxruntime/core/providers/cpu/ml/label_encoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include "core/framework/op_kernel.h"
#include "core/providers/cpu/ml/ml_common.h"
#include "core/framework/tensorprotoutils.h"
#include "core/common/safeint.h"

namespace onnxruntime {
namespace ml {
Expand Down Expand Up @@ -116,10 +117,11 @@ std::vector<T> GetAttribute(const OpKernelInfo& info, const std::string& name, c
} else {
ORT_ENFORCE(result.IsOK(), "LabelEncoder is missing attribute ", tensor_name, " or ", name);
}
size_t tensor_size = 1;
SafeInt<int64_t> element_count = 1;
for (auto dim : attr_tensor_proto.dims()) {
tensor_size *= dim;
element_count *= dim;
}
const SafeInt<size_t> tensor_size(element_count);
std::vector<T> out(tensor_size);
result = utils::UnpackTensor<T>(attr_tensor_proto, Path(), out.data(), tensor_size);
ORT_ENFORCE(result.IsOK(), "LabelEncoder could not unpack tensor attribute ", name);
Expand Down

0 comments on commit d39e5d0

Please sign in to comment.