Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add resnet cpp handler #2514

Open
wants to merge 11 commits into
base: cpp_backend
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 15 additions & 4 deletions cpp/build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ function install_dependencies_linux() {
autoconf \
automake \
git \
cmake \
m4 \
g++ \
flex \
Expand Down Expand Up @@ -175,6 +174,14 @@ function install_libtorch() {
wget https://download.pytorch.org/libtorch/cu116/libtorch-cxx11-abi-shared-with-deps-1.12.1%2Bcu116.zip
unzip libtorch-cxx11-abi-shared-with-deps-1.12.1+cu116.zip
rm libtorch-cxx11-abi-shared-with-deps-1.12.1+cu116.zip
elif [ "$CUDA" = "cu117" ]; then
wget https://download.pytorch.org/libtorch/cu117/libtorch-cxx11-abi-shared-with-deps-2.0.1%2Bcu117.zip
unzip libtorch-cxx11-abi-shared-with-deps-2.0.1+cu117.zip
rm libtorch-cxx11-abi-shared-with-deps-2.0.1+cu117.zip
elif [ "$CUDA" = "cu118" ]; then
wget https://download.pytorch.org/libtorch/cu118/libtorch-cxx11-abi-shared-with-deps-2.0.1%2Bcu118.zip
unzip libtorch-cxx11-abi-shared-with-deps-2.0.1+cu118.zip
rm libtorch-cxx11-abi-shared-with-deps-2.0.1+cu118.zip
else
wget https://download.pytorch.org/libtorch/cpu/libtorch-cxx11-abi-shared-with-deps-1.12.1%2Bcpu.zip
unzip libtorch-cxx11-abi-shared-with-deps-1.12.1+cpu.zip
Expand Down Expand Up @@ -254,7 +261,7 @@ function build() {
find $FOLLY_CMAKE_DIR -name "lib*.*" -exec ln -s "{}" $LIBS_DIR/ \;
if [ "$PLATFORM" = "Linux" ]; then
cmake \
-DCMAKE_PREFIX_PATH="$DEPS_DIR;$FOLLY_CMAKE_DIR;$YAML_CPP_CMAKE_DIR;$DEPS_DIR/libtorch" \
-DCMAKE_PREFIX_PATH="$DEPS_DIR;$FOLLY_CMAKE_DIR;$YAML_CPP_CMAKE_DIR;$DEPS_DIR/libtorch;" \
-DCMAKE_INSTALL_PREFIX="$PREFIX" \
"$MAYBE_BUILD_QUIC" \
"$MAYBE_BUILD_TESTS" \
Expand All @@ -265,7 +272,7 @@ function build() {
"$MAYBE_CUDA_COMPILER" \
..

if [ "$CUDA" = "cu102" ] || [ "$CUDA" = "cu113" ] || [ "$CUDA" = "cu116" ]; then
if [ "$CUDA" = "cu102" ] || [ "$CUDA" = "cu113" ] || [ "$CUDA" = "cu116" ] || [ "$CUDA" = "cu117" ] || [ "$CUDA" = "cu118" ]; then
export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/cuda/bin/nvcc
fi
elif [ "$PLATFORM" = "Mac" ]; then
Expand Down Expand Up @@ -299,6 +306,10 @@ function build() {
mv $DEPS_DIR/../src/examples/libmnist_handler.so $DEPS_DIR/../../test/resources/torchscript_model/mnist/mnist_handler/libmnist_handler.so
fi

if [ -f "$DEPS_DIR/../src/examples/libresnet-18_handler.so" ]; then
mv $DEPS_DIR/../src/examples/libresnet-18_handler.so $DEPS_DIR/../../test/resources/torchscript_model/resnet-18/resnet-18_handler/libresnet-18_handler.so
fi

cd $DEPS_DIR/../..
if [ -f "$DEPS_DIR/../test/torchserve_cpp_test" ]; then
$DEPS_DIR/../test/torchserve_cpp_test
Expand Down Expand Up @@ -329,7 +340,7 @@ INSTALL_DEPENDENCIES=false
PREFIX=""
COMPILER_FLAGS=""
CUDA=""
USAGE="./build.sh [-j num_jobs] [-g cu102|cu113|cu116] [-q|--with-quic] [--install-dependencies] [-p|--prefix] [-x|--compiler-flags]"
USAGE="./build.sh [-j num_jobs] [-g cu102|cu113|cu116|cu117|cu118] [-q|--with-quic] [--install-dependencies] [-p|--prefix] [-x|--compiler-flags]"
while [ "$1" != "" ]; do
case $1 in
-j | --jobs ) shift
Expand Down
17 changes: 16 additions & 1 deletion cpp/src/examples/CMakeLists.txt
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would be good to you create a local CMakeLists.txt in the image_classifier folder and use add_subfolder().

Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,19 @@ set(MNIST_SOURCE_FILES "")
list(APPEND MNIST_SOURCE_FILES ${MNIST_SRC_DIR}/mnist_handler.cc)
add_library(mnist_handler SHARED ${MNIST_SOURCE_FILES})
target_include_directories(mnist_handler PUBLIC ${MNIST_SRC_DIR})
target_link_libraries(mnist_handler PRIVATE ts_backends_torch_scripted ts_utils ${TORCH_LIBRARIES})
target_link_libraries(mnist_handler PRIVATE ts_backends_torch_scripted ts_utils ${TORCH_LIBRARIES})


set(RESNET_SRC_DIR "${torchserve_cpp_SOURCE_DIR}/src/examples/image_classifier/resnet-18")

set(RESNET_SOURCE_FILES "")
set(OPENCV_DIR "/usr/local/include/opencv4")
list(APPEND RESNET_SOURCE_FILES ${RESNET_SRC_DIR}/resnet-18_handler.cc)
add_library(resnet-18_handler SHARED ${RESNET_SOURCE_FILES})
target_include_directories(resnet-18_handler PUBLIC ${OPENCV_DIR})
target_include_directories(resnet-18_handler PUBLIC ${RESNET_SRC_DIR})
target_link_libraries(resnet-18_handler PRIVATE ts_backends_torch_scripted ts_utils ${TORCH_LIBRARIES})
target_link_libraries(resnet-18_handler PRIVATE "/usr/local/lib/libopencv_imgcodecs.so")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would be better to use find_package and use the parameters set there instead of absolute paths.

target_link_libraries(resnet-18_handler PRIVATE "/usr/local/lib/libopencv_cudawarping.so")
target_link_libraries(resnet-18_handler PRIVATE "/usr/local/lib/libopencv_cudaimgproc.so")
target_link_libraries(resnet-18_handler PRIVATE "/usr/local/lib/libopencv_core.so")
261 changes: 261 additions & 0 deletions cpp/src/examples/image_classifier/resnet-18/resnet-18_handler.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,261 @@
#include "src/examples/image_classifier/resnet-18/resnet-18_handler.hh"

#include <folly/json.h>

#include <fstream>
#include <opencv2/cudaimgproc.hpp>
#include <opencv2/cudawarping.hpp>
#include <opencv2/opencv.hpp>

namespace resnet {

constexpr int kTargetImageSize = 224;
constexpr double kImageNormalizationMeanR = 0.485;
constexpr double kImageNormalizationMeanG = 0.456;
constexpr double kImageNormalizationMeanB = 0.406;
constexpr double kImageNormalizationStdR = 0.229;
constexpr double kImageNormalizationStdG = 0.224;
constexpr double kImageNormalizationStdB = 0.225;
constexpr int kTopKClasses = 5;

std::vector<torch::jit::IValue> ResnetHandler::Preprocess(
std::shared_ptr<torch::Device>& device,
std::pair<std::string&, std::map<uint8_t, std::string>&>& idx_to_req_id,
std::shared_ptr<torchserve::InferenceRequestBatch>& request_batch,
std::shared_ptr<torchserve::InferenceResponseBatch>& response_batch) {
std::vector<torch::jit::IValue> batch_ivalue;
std::vector<torch::Tensor> batch_tensors;
uint8_t idx = 0;
for (auto& request : *request_batch) {
(*response_batch)[request.request_id] =
std::make_shared<torchserve::InferenceResponse>(request.request_id);
idx_to_req_id.first += idx_to_req_id.first.empty()
? request.request_id
: "," + request.request_id;
auto data_it =
request.parameters.find(torchserve::PayloadType::kPARAMETER_NAME_DATA);
auto dtype_it =
request.headers.find(torchserve::PayloadType::kHEADER_NAME_DATA_TYPE);
if (data_it == request.parameters.end()) {
data_it = request.parameters.find(
torchserve::PayloadType::kPARAMETER_NAME_BODY);
dtype_it =
request.headers.find(torchserve::PayloadType::kHEADER_NAME_BODY_TYPE);
}

if (data_it == request.parameters.end() ||
dtype_it == request.headers.end()) {
TS_LOGF(ERROR, "Empty payload for request id: {}", request.request_id);
(*response_batch)[request.request_id]->SetResponse(
500, "data_type", torchserve::PayloadType::kCONTENT_TYPE_TEXT,
"Empty payload");
continue;
}
/*
case2: the image is sent as string of bytesarray
if (dtype_it->second == "String") {
try {
auto b64decoded_str = folly::base64Decode(data_it->second);
torchserve::Converter::StrToBytes(b64decoded_str, image);
} catch (folly::base64_decode_error e) {
TS_LOGF(ERROR, "Failed to base64Decode for request id: {}, error: {}",
request.request_id,
e.what());
}
}
*/

try {
if (dtype_it->second == torchserve::PayloadType::kDATA_TYPE_BYTES) {
// case2: the image is sent as bytesarray
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are these commented out lines still needed?

// torch::serialize::InputArchive archive;
// archive.load_from(std::istringstream
// iss(std::string(data_it->second)));
/*
std::istringstream iss(std::string(data_it->second.begin(),
data_it->second.end())); torch::serialize::InputArchive archive;
images.emplace_back(archive.load_from(iss, torch::Device device);

std::vector<char> bytes(
static_cast<char>(*data_it->second.begin()),
static_cast<char>(*data_it->second.end()));

images.emplace_back(torch::pickle_load(bytes).toTensor().to(*device));
*/

cv::Mat image = cv::imdecode(data_it->second, cv::IMREAD_COLOR);

// Check if the image was successfully decoded
if (image.empty()) {
std::cerr << "Failed to decode the image.\n";
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we continue in case the image was unsuccessfully loaded? What happens with the code below if image is empty?

}

const int rows = image.rows;
const int cols = image.cols;

const int cropSize = std::min(rows, cols);
const int offsetW = (cols - cropSize) / 2;
const int offsetH = (rows - cropSize) / 2;

const cv::Rect roi(offsetW, offsetH, cropSize, cropSize);
image = image(roi);

// Convert the image to GPU Mat
cv::cuda::GpuMat gpuImage;
cv::Mat resultImage;

gpuImage.upload(image);

// Resize on GPU
cv::cuda::resize(gpuImage, gpuImage,
cv::Size(kTargetImageSize, kTargetImageSize));

// Convert to BGR on GPU
cv::cuda::cvtColor(gpuImage, gpuImage, cv::COLOR_BGR2RGB);

// Convert to float on GPU
gpuImage.convertTo(gpuImage, CV_32FC3, 1 / 255.0);

// Download the final image from GPU to CPU
gpuImage.download(resultImage);

// Create a tensor from the CPU Mat
torch::Tensor tensorImage = torch::from_blob(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a way to create the tensor on GPU? Avoiding the gpuImage.download?

resultImage.data, {resultImage.rows, resultImage.cols, 3},
torch::kFloat);
tensorImage = tensorImage.permute({2, 0, 1});

std::vector<double> norm_mean = {kImageNormalizationMeanR,
kImageNormalizationMeanG,
kImageNormalizationMeanB};
std::vector<double> norm_std = {kImageNormalizationStdR,
kImageNormalizationStdG,
kImageNormalizationStdB};

// Normalize the tensor
tensorImage = torch::data::transforms::Normalize<>(
norm_mean, norm_std)(tensorImage);

tensorImage.clone();
batch_tensors.emplace_back(tensorImage.to(*device));
idx_to_req_id.second[idx++] = request.request_id;
} else if (dtype_it->second == "List") {
// case3: the image is a list
}
} catch (const std::runtime_error& e) {
TS_LOGF(ERROR, "Failed to load tensor for request id: {}, error: {}",
request.request_id, e.what());
auto response = (*response_batch)[request.request_id];
response->SetResponse(500, "data_type",
torchserve::PayloadType::kDATA_TYPE_STRING,
"runtime_error, failed to load tensor");
} catch (const c10::Error& e) {
TS_LOGF(ERROR, "Failed to load tensor for request id: {}, c10 error: {}",
request.request_id, e.msg());
auto response = (*response_batch)[request.request_id];
response->SetResponse(500, "data_type",
torchserve::PayloadType::kDATA_TYPE_STRING,
"c10 error, failed to load tensor");
}
}
if (!batch_tensors.empty()) {
batch_ivalue.emplace_back(torch::stack(batch_tensors).to(*device));
}

return batch_ivalue;
}

void ResnetHandler::Postprocess(
const torch::Tensor& data,
std::pair<std::string&, std::map<uint8_t, std::string>&>& idx_to_req_id,
std::shared_ptr<torchserve::InferenceResponseBatch>& response_batch) {
std::ifstream jsonFile("index_to_name.json");
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you move loading of the file into the constructor to avoid reloading every time we do postprocessing?

if (!jsonFile.is_open()) {
std::cerr << "Failed to open JSON file.\n";
}
std::string jsonString((std::istreambuf_iterator<char>(jsonFile)),
std::istreambuf_iterator<char>());
jsonFile.close();
folly::dynamic parsedJson = folly::parseJson(jsonString);
if (!parsedJson.isObject()) {
std::cerr << "Invalid JSON format.\n";
}
for (const auto& kv : idx_to_req_id.second) {
try {
auto response = (*response_batch)[kv.second];
namespace F = torch::nn::functional;

// Perform softmax and top-k operations
torch::Tensor ps = F::softmax(data, F::SoftmaxFuncOptions(1));
std::tuple<torch::Tensor, torch::Tensor> result =
torch::topk(ps, kTopKClasses, 1, true, true);
torch::Tensor probs = std::get<0>(result);
torch::Tensor classes = std::get<1>(result);

probs = probs.to(torch::kCPU);
classes = classes.to(torch::kCPU);
// Convert tensors to C++ vectors
std::vector<float> probs_vector(probs.data_ptr<float>(),
probs.data_ptr<float>() + probs.numel());
std::vector<long> classes_vector(
classes.data_ptr<long>(), classes.data_ptr<long>() + classes.numel());

// Create a JSON object using folly::dynamic
folly::dynamic json_response = folly::dynamic::object;
// Create a folly::dynamic array to hold tensor elements
folly::dynamic probability = folly::dynamic::array;
folly::dynamic class_names = folly::dynamic::array;

// Iterate through tensor elements and add them to the dynamic_array
for (const float& value : probs_vector) {
probability.push_back(value);
}
for (const long& value : classes_vector) {
class_names.push_back(value);
}
// Add key-value pairs to the JSON object
json_response["probability"] = probability;
json_response["classes"] = class_names;

// Serialize the JSON object to a string
std::string json_str = folly::toJson(json_response);

// Serialize and set the response
response->SetResponse(200, "data_tpye",
torchserve::PayloadType::kDATA_TYPE_BYTES,
json_str);
} catch (const std::runtime_error& e) {
LOG(ERROR) << "Failed to load tensor for request id:" << kv.second
<< ", error: " << e.what();
auto response = (*response_batch)[kv.second];
response->SetResponse(500, "data_tpye",
torchserve::PayloadType::kDATA_TYPE_STRING,
"runtime_error, failed to load tensor");
throw e;
} catch (const c10::Error& e) {
LOG(ERROR) << "Failed to load tensor for request id:" << kv.second
<< ", c10 error: " << e.msg();
auto response = (*response_batch)[kv.second];
response->SetResponse(500, "data_tpye",
torchserve::PayloadType::kDATA_TYPE_STRING,
"c10 error, failed to load tensor");
throw e;
}
}
}

} // namespace resnet

#if defined(__linux__) || defined(__APPLE__)
extern "C" {
torchserve::torchscripted::BaseHandler* allocatorResnetHandler() {
return new resnet::ResnetHandler();
}

void deleterResnetHandler(torchserve::torchscripted::BaseHandler* p) {
if (p != nullptr) {
delete static_cast<resnet::ResnetHandler*>(p);
}
}
}
#endif
28 changes: 28 additions & 0 deletions cpp/src/examples/image_classifier/resnet-18/resnet-18_handler.hh
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
#ifndef RESNET_HANDLER_HH_
#define RESNET_HANDLER_HH_

#include "src/backends/torch_scripted/handler/base_handler.hh"

namespace resnet {
class ResnetHandler : public torchserve::torchscripted::BaseHandler {
public:
// NOLINTBEGIN(bugprone-exception-escape)
ResnetHandler() = default;
// NOLINTEND(bugprone-exception-escape)
~ResnetHandler() override = default;

std::vector<torch::jit::IValue> Preprocess(
std::shared_ptr<torch::Device>& device,
std::pair<std::string&, std::map<uint8_t, std::string>&>& idx_to_req_id,
std::shared_ptr<torchserve::InferenceRequestBatch>& request_batch,
std::shared_ptr<torchserve::InferenceResponseBatch>& response_batch)
override;

void Postprocess(
const torch::Tensor& data,
std::pair<std::string&, std::map<uint8_t, std::string>&>& idx_to_req_id,
std::shared_ptr<torchserve::InferenceResponseBatch>& response_batch)
override;
};
} // namespace resnet
#endif // RESNET_HANDLER_HH_
10 changes: 10 additions & 0 deletions cpp/test/backends/torch_scripted/torch_scripted_backend_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,16 @@ TEST_F(TorchScriptedBackendTest, TestLoadPredictMnistHandler) {
"mnist_ts", 200);
}

TEST_F(TorchScriptedBackendTest, TestLoadPredictResnetHandler) {
this->LoadPredict(
std::make_shared<torchserve::LoadModelRequest>(
"test/resources/torchscript_model/resnet-18/resnet-18_handler",
"resnet-18", -1, "", "", 1, false),
"test/resources/torchscript_model/resnet-18/resnet-18_handler",
"test/resources/torchscript_model/resnet-18/kitten.jpg", "resnet-18_ts",
200);
}

TEST_F(TorchScriptedBackendTest, TestBackendInitWrongModelDir) {
auto result = backend_->Initialize("test/resources/torchscript_model/mnist");
ASSERT_EQ(result, false);
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Loading