[Model] Yolov5/v5lite/v6/v7/v7end2end: CUDA preprocessing (#370)
* add yolo cuda preprocessing

* cmake build cuda src

* yolov5 support cuda preprocessing

* yolov5 cuda preprocessing configurable

* yolov5 update get mat data api

* yolov5 check cuda preprocess args

* refactor cuda function name

* yolo cuda preprocess padding value configurable

* yolov5 release cuda memory

* cuda preprocess pybind api update

* move use_cuda_preprocessing option to yolov5 model

* yolov5lite cuda preprocessing

* yolov6 cuda preprocessing

* yolov7 cuda preprocessing

* yolov7_e2e cuda preprocessing

* remove cuda preprocessing in runtime option

* refine log and cmake variable name

* fix model runtime ptr type

Co-authored-by: Jason <jiangjiajun@baidu.com>
wang-xinyu and jiangjiajun authored Oct 19, 2022
1 parent 4b3e932 commit c8d6c82
Showing 26 changed files with 752 additions and 24 deletions.
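For orientation before the diffs: the feature is opt-in per model through the new UseCudaPreprocessing() API. A minimal C++ usage sketch follows — it is not part of this commit; the model path, image path, and GPU id are placeholders, and the two-argument Predict call assumes its default thresholds:

```cpp
#include <iostream>

#include "fastdeploy/vision.h"

int main() {
  fastdeploy::RuntimeOption option;
  option.UseGpu(0);  // CUDA preprocessing requires a GPU build and a device
  auto model =
      fastdeploy::vision::detection::YOLOv5("yolov5s.onnx", "", option);
  if (!model.Initialized()) return -1;
  // Opt in to the CUDA preprocessing path added by this commit; the argument
  // caps the supported input image size in pixels (default 3840 * 2160).
  model.UseCudaPreprocessing(3840 * 2160);
  cv::Mat im = cv::imread("test.jpg");
  fastdeploy::vision::DetectionResult res;
  if (!model.Predict(&im, &res)) return -1;
  std::cout << res.Str() << std::endl;
  return 0;
}
```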
23 changes: 22 additions & 1 deletion CMakeLists.txt
@@ -92,6 +92,16 @@ if(BUILD_ON_JETSON)
set(ENABLE_ORT_BACKEND ON)
endif()

# Whether to build CUDA source files in fastdeploy
# CUDA source files include CUDA preprocessing, TRT plugins, etc.
if(WITH_GPU AND UNIX)
set(BUILD_CUDA_SRC ON)
enable_language(CUDA)
set(CUDA_PROPAGATE_HOST_FLAGS FALSE)
else()
set(BUILD_CUDA_SRC OFF)
endif()

# config GIT_URL with github mirrors to speed up dependent repos clone
option(GIT_URL "Git URL to clone dependent repos" ${GIT_URL})
if(NOT GIT_URL)
@@ -174,6 +184,7 @@ file(GLOB_RECURSE DEPLOY_TRT_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastde
file(GLOB_RECURSE DEPLOY_OPENVINO_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/backends/openvino/*.cc)
file(GLOB_RECURSE DEPLOY_LITE_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/backends/lite/*.cc)
file(GLOB_RECURSE DEPLOY_VISION_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/vision/*.cc)
file(GLOB_RECURSE DEPLOY_VISION_CUDA_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/vision/*.cu)
file(GLOB_RECURSE DEPLOY_TEXT_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/text/*.cc)
file(GLOB_RECURSE DEPLOY_PYBIND_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/pybind/*.cc ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/*_pybind.cc)
list(REMOVE_ITEM ALL_DEPLOY_SRCS ${DEPLOY_ORT_SRCS} ${DEPLOY_PADDLE_SRCS} ${DEPLOY_POROS_SRCS} ${DEPLOY_TRT_SRCS} ${DEPLOY_OPENVINO_SRCS} ${DEPLOY_LITE_SRCS} ${DEPLOY_VISION_SRCS} ${DEPLOY_TEXT_SRCS})
@@ -373,6 +384,10 @@ if(ENABLE_VISION)
endif()
add_subdirectory(${PROJECT_SOURCE_DIR}/third_party/yaml-cpp)
list(APPEND DEPEND_LIBS yaml-cpp)
if(BUILD_CUDA_SRC)
add_definitions(-DENABLE_CUDA_PREPROCESS)
list(APPEND DEPLOY_VISION_SRCS ${DEPLOY_VISION_CUDA_SRCS})
endif()
list(APPEND ALL_DEPLOY_SRCS ${DEPLOY_VISION_SRCS})
include_directories(${PROJECT_SOURCE_DIR}/third_party/yaml-cpp/include)
include(${PROJECT_SOURCE_DIR}/cmake/opencv.cmake)
@@ -428,7 +443,13 @@ elseif(ANDROID)
set_target_properties(${LIBRARY_NAME} PROPERTIES LINK_FLAGS_MINSIZEREL ${COMMON_LINK_FLAGS_REL})
elseif(MSVC)
else()
-  set_target_properties(${LIBRARY_NAME} PROPERTIES COMPILE_FLAGS "-fvisibility=hidden")
+  if(BUILD_CUDA_SRC)
+    set_target_properties(${LIBRARY_NAME} PROPERTIES CUDA_SEPARABLE_COMPILATION ON)
+    set_target_properties(${LIBRARY_NAME} PROPERTIES INTERFACE_COMPILE_OPTIONS
+      "$<$<BUILD_INTERFACE:$<COMPILE_LANGUAGE:CXX>>:-fvisibility=hidden>$<$<BUILD_INTERFACE:$<COMPILE_LANGUAGE:CUDA>>:-Xcompiler=-fvisibility=hidden>")
+  else()
+    set_target_properties(${LIBRARY_NAME} PROPERTIES COMPILE_FLAGS "-fvisibility=hidden")
+  endif()
set_target_properties(${LIBRARY_NAME} PROPERTIES LINK_FLAGS "-Wl,--exclude-libs,ALL")
set_target_properties(${LIBRARY_NAME} PROPERTIES LINK_FLAGS_RELEASE -s)
endif()
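One detail in the hunk above: nvcc does not accept -fvisibility=hidden directly, so when CUDA sources are built the generator expression passes the flag as-is to C++ translation units and wraps it in -Xcompiler= for CUDA translation units.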
1 change: 1 addition & 0 deletions cmake/summary.cmake
@@ -51,6 +51,7 @@ function(fastdeploy_summary)
message(STATUS " WITH_GPU : ${WITH_GPU}")
message(STATUS " CUDA_DIRECTORY : ${CUDA_DIRECTORY}")
message(STATUS " TRT_DRECTORY : ${TRT_DIRECTORY}")
message(STATUS " BUILD_CUDA_SRC : ${BUILD_CUDA_SRC}")
endif()
message(STATUS " ENABLE_VISION : ${ENABLE_VISION}")
message(STATUS " ENABLE_TEXT : ${ENABLE_TEXT}")
6 changes: 3 additions & 3 deletions fastdeploy/fastdeploy_model.cc
@@ -57,7 +57,7 @@ bool FastDeployModel::InitRuntime() {
}

if (is_supported) {
-    runtime_ = std::unique_ptr<Runtime>(new Runtime());
+    runtime_ = std::shared_ptr<Runtime>(new Runtime());
if (!runtime_->Init(runtime_option)) {
return false;
}
@@ -107,7 +107,7 @@ bool FastDeployModel::CreateCpuBackend() {
continue;
}
runtime_option.backend = valid_cpu_backends[i];
-    runtime_ = std::unique_ptr<Runtime>(new Runtime());
+    runtime_ = std::shared_ptr<Runtime>(new Runtime());
if (!runtime_->Init(runtime_option)) {
return false;
}
@@ -130,7 +130,7 @@ bool FastDeployModel::CreateGpuBackend() {
continue;
}
runtime_option.backend = valid_gpu_backends[i];
-    runtime_ = std::unique_ptr<Runtime>(new Runtime());
+    runtime_ = std::shared_ptr<Runtime>(new Runtime());
if (!runtime_->Init(runtime_option)) {
return false;
}
2 changes: 1 addition & 1 deletion fastdeploy/fastdeploy_model.h
@@ -99,7 +99,7 @@ class FASTDEPLOY_DECL FastDeployModel {
std::vector<Backend> valid_external_backends;

private:
-  std::unique_ptr<Runtime> runtime_;
+  std::shared_ptr<Runtime> runtime_;
bool runtime_initialized_ = false;
// whether to record inference time
bool enable_record_time_of_runtime_ = false;
96 changes: 91 additions & 5 deletions fastdeploy/vision/detection/contrib/yolov5.cc
@@ -16,6 +16,9 @@

#include "fastdeploy/utils/perf.h"
#include "fastdeploy/vision/utils/utils.h"
#ifdef ENABLE_CUDA_PREPROCESS
#include "fastdeploy/vision/utils/cuda_utils.h"
#endif // ENABLE_CUDA_PREPROCESS

namespace fastdeploy {
namespace vision {
@@ -104,9 +107,20 @@ bool YOLOv5::Initialize() {
// if (!is_dynamic_input_) {
// is_mini_pad_ = false;
// }

return true;
}

YOLOv5::~YOLOv5() {
#ifdef ENABLE_CUDA_PREPROCESS
if (use_cuda_preprocessing_) {
CUDA_CHECK(cudaFreeHost(input_img_cuda_buffer_host_));
CUDA_CHECK(cudaFree(input_img_cuda_buffer_device_));
CUDA_CHECK(cudaFree(input_tensor_cuda_buffer_device_));
}
#endif // ENABLE_CUDA_PREPROCESS
}

bool YOLOv5::Preprocess(Mat* mat, FDTensor* output,
std::map<std::string, std::array<float, 2>>* im_info,
const std::vector<int>& size,
@@ -156,6 +170,69 @@ bool YOLOv5::Preprocess(Mat* mat, FDTensor* output,
return true;
}

void YOLOv5::UseCudaPreprocessing(int max_image_size) {
#ifdef ENABLE_CUDA_PREPROCESS
use_cuda_preprocessing_ = true;
is_scale_up_ = true;
if (input_img_cuda_buffer_host_ == nullptr) {
    // prepare input image cache in pinned host memory
CUDA_CHECK(cudaMallocHost((void**)&input_img_cuda_buffer_host_, max_image_size * 3));
// prepare input data cache in GPU device memory
CUDA_CHECK(cudaMalloc((void**)&input_img_cuda_buffer_device_, max_image_size * 3));
CUDA_CHECK(cudaMalloc((void**)&input_tensor_cuda_buffer_device_, 3 * size_[0] * size_[1] * sizeof(float)));
}
#else
  FDWARNING << "FastDeploy was not compiled with BUILD_CUDA_SRC=ON, "
            << "CUDA preprocessing is disabled." << std::endl;
use_cuda_preprocessing_ = false;
#endif
}

bool YOLOv5::CudaPreprocess(Mat* mat, FDTensor* output,
std::map<std::string, std::array<float, 2>>* im_info,
const std::vector<int>& size,
const std::vector<float> padding_value,
bool is_mini_pad, bool is_no_pad, bool is_scale_up,
int stride, float max_wh, bool multi_label) {
#ifdef ENABLE_CUDA_PREPROCESS
  if (is_mini_pad || is_no_pad || !is_scale_up) {
FDERROR << "Preprocessing with CUDA is only available when the arguments satisfy (is_mini_pad=false, is_no_pad=false, is_scale_up=true)." << std::endl;
return false;
}

// Record the shape of image and the shape of preprocessed image
(*im_info)["input_shape"] = {static_cast<float>(mat->Height()),
static_cast<float>(mat->Width())};
(*im_info)["output_shape"] = {static_cast<float>(mat->Height()),
static_cast<float>(mat->Width())};

cudaStream_t stream;
CUDA_CHECK(cudaStreamCreate(&stream));
int src_img_buf_size = mat->Height() * mat->Width() * mat->Channels();
memcpy(input_img_cuda_buffer_host_, mat->Data(), src_img_buf_size);
CUDA_CHECK(cudaMemcpyAsync(input_img_cuda_buffer_device_,
input_img_cuda_buffer_host_,
src_img_buf_size, cudaMemcpyHostToDevice, stream));
utils::CudaYoloPreprocess(input_img_cuda_buffer_device_, mat->Width(),
mat->Height(), input_tensor_cuda_buffer_device_,
size[0], size[1], padding_value, stream);
cudaStreamSynchronize(stream);
cudaStreamDestroy(stream);

// Record output shape of preprocessed image
(*im_info)["output_shape"] = {static_cast<float>(size[0]), static_cast<float>(size[1])};

output->SetExternalData({mat->Channels(), size[0], size[1]}, FDDataType::FP32,
input_tensor_cuda_buffer_device_);
output->device = Device::GPU;
  output->shape.insert(output->shape.begin(), 1);  // prepend batch dim: n, c, h, w
return true;
#else
FDERROR << "CUDA src code was not enabled." << std::endl;
return false;
#endif // ENABLE_CUDA_PREPROCESS
}

bool YOLOv5::Postprocess(
std::vector<FDTensor>& infer_results, DetectionResult* result,
const std::map<std::string, std::array<float, 2>>& im_info,
@@ -262,11 +339,20 @@ bool YOLOv5::Predict(cv::Mat* im, DetectionResult* result, float conf_threshold,

std::map<std::string, std::array<float, 2>> im_info;

-  if (!Preprocess(&mat, &input_tensors[0], &im_info, size_, padding_value_,
-                  is_mini_pad_, is_no_pad_, is_scale_up_, stride_, max_wh_,
-                  multi_label_)) {
-    FDERROR << "Failed to preprocess input image." << std::endl;
-    return false;
+  if (use_cuda_preprocessing_) {
+    if (!CudaPreprocess(&mat, &input_tensors[0], &im_info, size_, padding_value_,
+                        is_mini_pad_, is_no_pad_, is_scale_up_, stride_, max_wh_,
+                        multi_label_)) {
+      FDERROR << "Failed to preprocess input image." << std::endl;
+      return false;
+    }
+  } else {
+    if (!Preprocess(&mat, &input_tensors[0], &im_info, size_, padding_value_,
+                    is_mini_pad_, is_no_pad_, is_scale_up_, stride_, max_wh_,
+                    multi_label_)) {
+      FDERROR << "Failed to preprocess input image." << std::endl;
+      return false;
+    }
   }

input_tensors[0].name = InputInfoOfRuntime(0).name;
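The CUDA path above leans on two helpers from fastdeploy/vision/utils/cuda_utils.h that this page does not show: the CUDA_CHECK macro and utils::CudaYoloPreprocess. The sketch below is an assumption-laden simplification, not the committed implementation — CUDA_CHECK in its conventional abort-on-error form, and a letterbox kernel reduced to nearest-neighbor sampling (the committed kernel presumably interpolates); YoloLetterboxKernel and CudaYoloPreprocessSketch are illustrative names:

```cuda
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <algorithm>
#include <vector>

#include <cuda_runtime.h>

// Conventional shape of a CUDA_CHECK macro: report and abort on any API error.
#define CUDA_CHECK(call)                                                  \
  do {                                                                    \
    cudaError_t err__ = (call);                                           \
    if (err__ != cudaSuccess) {                                           \
      std::fprintf(stderr, "CUDA error %s at %s:%d\n",                    \
                   cudaGetErrorString(err__), __FILE__, __LINE__);        \
      std::abort();                                                       \
    }                                                                     \
  } while (0)

// Simplified stand-in for the kernel behind utils::CudaYoloPreprocess:
// letterbox the packed BGR source into a padded destination, convert to RGB,
// scale to [0, 1], and write planar CHW floats.
__global__ void YoloLetterboxKernel(const uint8_t* src, int src_w, int src_h,
                                    float* dst, int dst_w, int dst_h,
                                    float scale, int pad_x, int pad_y,
                                    float pad_b, float pad_g, float pad_r) {
  int x = blockIdx.x * blockDim.x + threadIdx.x;
  int y = blockIdx.y * blockDim.y + threadIdx.y;
  if (x >= dst_w || y >= dst_h) return;
  float b = pad_b, g = pad_g, r = pad_r;  // padding color outside the image
  float fx = (x - pad_x) / scale;         // map back into the source image
  float fy = (y - pad_y) / scale;
  if (fx >= 0.f && fx < src_w && fy >= 0.f && fy < src_h) {
    const uint8_t* p =
        src + (static_cast<int>(fy) * src_w + static_cast<int>(fx)) * 3;
    b = p[0]; g = p[1]; r = p[2];
  }
  int area = dst_w * dst_h;
  dst[0 * area + y * dst_w + x] = r / 255.0f;  // R plane
  dst[1 * area + y * dst_w + x] = g / 255.0f;  // G plane
  dst[2 * area + y * dst_w + x] = b / 255.0f;  // B plane
}

// Host-side launch mirroring the call site in CudaPreprocess above.
void CudaYoloPreprocessSketch(const uint8_t* src_dev, int src_w, int src_h,
                              float* dst_dev, int dst_w, int dst_h,
                              const std::vector<float>& padding_value,
                              cudaStream_t stream) {
  float scale = std::min(dst_w / static_cast<float>(src_w),
                         dst_h / static_cast<float>(src_h));
  int pad_x = (dst_w - static_cast<int>(src_w * scale)) / 2;
  int pad_y = (dst_h - static_cast<int>(src_h * scale)) / 2;
  dim3 block(16, 16);
  dim3 grid((dst_w + block.x - 1) / block.x, (dst_h + block.y - 1) / block.y);
  YoloLetterboxKernel<<<grid, block, 0, stream>>>(
      src_dev, src_w, src_h, dst_dev, dst_w, dst_h, scale, pad_x, pad_y,
      padding_value[0], padding_value[1], padding_value[2]);
  CUDA_CHECK(cudaGetLastError());
}
```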
22 changes: 22 additions & 0 deletions fastdeploy/vision/detection/contrib/yolov5.h
@@ -13,6 +13,7 @@
// limitations under the License.

#pragma once

#include "fastdeploy/fastdeploy_model.h"
#include "fastdeploy/vision/common/processors/transform.h"
#include "fastdeploy/vision/common/result.h"
@@ -27,6 +28,8 @@ class FASTDEPLOY_DECL YOLOv5 : public FastDeployModel {
const RuntimeOption& custom_option = RuntimeOption(),
const ModelFormat& model_format = ModelFormat::ONNX);

~YOLOv5();

std::string ModelName() const { return "yolov5"; }

virtual bool Predict(cv::Mat* im, DetectionResult* result,
@@ -42,6 +45,17 @@ class FASTDEPLOY_DECL YOLOv5 : public FastDeployModel {
bool is_scale_up = false, int stride = 32,
float max_wh = 7680.0, bool multi_label = true);

void UseCudaPreprocessing(int max_img_size = 3840 * 2160);

bool CudaPreprocess(Mat* mat, FDTensor* output,
std::map<std::string, std::array<float, 2>>* im_info,
const std::vector<int>& size = {640, 640},
const std::vector<float> padding_value = {114.0, 114.0,
114.0},
bool is_mini_pad = false, bool is_no_pad = false,
bool is_scale_up = false, int stride = 32,
float max_wh = 7680.0, bool multi_label = true);

static bool Postprocess(
std::vector<FDTensor>& infer_results, DetectionResult* result,
const std::map<std::string, std::array<float, 2>>& im_info,
@@ -85,6 +99,14 @@ class FASTDEPLOY_DECL YOLOv5 : public FastDeployModel {
// value will
// auto check by fastdeploy after the internal Runtime already initialized.
bool is_dynamic_input_;
// CUDA host buffer for input image
uint8_t* input_img_cuda_buffer_host_ = nullptr;
// CUDA device buffer for input image
uint8_t* input_img_cuda_buffer_device_ = nullptr;
// CUDA device buffer for TRT input tensor
float* input_tensor_cuda_buffer_device_ = nullptr;
// Whether to use CUDA preprocessing
bool use_cuda_preprocessing_ = false;
};

} // namespace detection
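For scale, with the defaults declared here these one-time allocations are modest: at max_img_size = 3840 * 2160 the pinned host and device image buffers each take 3840 × 2160 × 3 B ≈ 23.7 MiB, and for the default 640 × 640 input the tensor buffer takes 3 × 640 × 640 × 4 B ≈ 4.7 MiB; all three are allocated on the first UseCudaPreprocessing() call and released by the new destructor.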
4 changes: 4 additions & 0 deletions fastdeploy/vision/detection/contrib/yolov5_pybind.cc
@@ -27,6 +27,10 @@ void BindYOLOv5(pybind11::module& m) {
self.Predict(&mat, &res, conf_threshold, nms_iou_threshold);
return res;
})
.def("use_cuda_preprocessing",
[](vision::detection::YOLOv5& self, int max_image_size) {
self.UseCudaPreprocessing(max_image_size);
})
.def_static("preprocess",
[](pybind11::array& data, const std::vector<int>& size,
const std::vector<float> padding_value, bool is_mini_pad,
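With this binding, the same switch is reachable from Python — presumably (module path assumed to mirror the C++ namespaces) as model.use_cuda_preprocessing(3840 * 2160) on a fastdeploy.vision.detection.YOLOv5 instance, called before predict.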
85 changes: 82 additions & 3 deletions fastdeploy/vision/detection/contrib/yolov5lite.cc
@@ -15,6 +15,9 @@
#include "fastdeploy/vision/detection/contrib/yolov5lite.h"
#include "fastdeploy/utils/perf.h"
#include "fastdeploy/vision/utils/utils.h"
#ifdef ENABLE_CUDA_PREPROCESS
#include "fastdeploy/vision/utils/cuda_utils.h"
#endif // ENABLE_CUDA_PREPROCESS

namespace fastdeploy {
namespace vision {
@@ -136,6 +139,16 @@ bool YOLOv5Lite::Initialize() {
return true;
}

YOLOv5Lite::~YOLOv5Lite() {
#ifdef ENABLE_CUDA_PREPROCESS
if (use_cuda_preprocessing_) {
CUDA_CHECK(cudaFreeHost(input_img_cuda_buffer_host_));
CUDA_CHECK(cudaFree(input_img_cuda_buffer_device_));
CUDA_CHECK(cudaFree(input_tensor_cuda_buffer_device_));
}
#endif // ENABLE_CUDA_PREPROCESS
}

bool YOLOv5Lite::Preprocess(
Mat* mat, FDTensor* output,
std::map<std::string, std::array<float, 2>>* im_info) {
@@ -176,6 +189,65 @@ bool YOLOv5Lite::Preprocess(
return true;
}

void YOLOv5Lite::UseCudaPreprocessing(int max_image_size) {
#ifdef ENABLE_CUDA_PREPROCESS
use_cuda_preprocessing_ = true;
is_scale_up = true;
if (input_img_cuda_buffer_host_ == nullptr) {
    // prepare input image cache in pinned host memory
CUDA_CHECK(cudaMallocHost((void**)&input_img_cuda_buffer_host_, max_image_size * 3));
// prepare input data cache in GPU device memory
CUDA_CHECK(cudaMalloc((void**)&input_img_cuda_buffer_device_, max_image_size * 3));
CUDA_CHECK(cudaMalloc((void**)&input_tensor_cuda_buffer_device_, 3 * size[0] * size[1] * sizeof(float)));
}
#else
  FDWARNING << "FastDeploy was not compiled with BUILD_CUDA_SRC=ON, "
            << "CUDA preprocessing is disabled." << std::endl;
use_cuda_preprocessing_ = false;
#endif
}

bool YOLOv5Lite::CudaPreprocess(Mat* mat, FDTensor* output,
std::map<std::string, std::array<float, 2>>* im_info) {
#ifdef ENABLE_CUDA_PREPROCESS
  if (is_mini_pad || is_no_pad || !is_scale_up) {
FDERROR << "Preprocessing with CUDA is only available when the arguments satisfy (is_mini_pad=false, is_no_pad=false, is_scale_up=true)." << std::endl;
return false;
}

// Record the shape of image and the shape of preprocessed image
(*im_info)["input_shape"] = {static_cast<float>(mat->Height()),
static_cast<float>(mat->Width())};
(*im_info)["output_shape"] = {static_cast<float>(mat->Height()),
static_cast<float>(mat->Width())};

cudaStream_t stream;
CUDA_CHECK(cudaStreamCreate(&stream));
int src_img_buf_size = mat->Height() * mat->Width() * mat->Channels();
memcpy(input_img_cuda_buffer_host_, mat->Data(), src_img_buf_size);
CUDA_CHECK(cudaMemcpyAsync(input_img_cuda_buffer_device_,
input_img_cuda_buffer_host_,
src_img_buf_size, cudaMemcpyHostToDevice, stream));
utils::CudaYoloPreprocess(input_img_cuda_buffer_device_, mat->Width(),
mat->Height(), input_tensor_cuda_buffer_device_,
size[0], size[1], padding_value, stream);
cudaStreamSynchronize(stream);
cudaStreamDestroy(stream);

// Record output shape of preprocessed image
(*im_info)["output_shape"] = {static_cast<float>(size[0]), static_cast<float>(size[1])};

output->SetExternalData({mat->Channels(), size[0], size[1]}, FDDataType::FP32,
input_tensor_cuda_buffer_device_);
output->device = Device::GPU;
  output->shape.insert(output->shape.begin(), 1);  // prepend batch dim: n, c, h, w
return true;
#else
FDERROR << "CUDA src code was not enabled." << std::endl;
return false;
#endif // ENABLE_CUDA_PREPROCESS
}

bool YOLOv5Lite::PostprocessWithDecode(
FDTensor& infer_result, DetectionResult* result,
const std::map<std::string, std::array<float, 2>>& im_info,
@@ -348,9 +420,16 @@ bool YOLOv5Lite::Predict(cv::Mat* im, DetectionResult* result,
im_info["output_shape"] = {static_cast<float>(mat.Height()),
static_cast<float>(mat.Width())};

-  if (!Preprocess(&mat, &input_tensors[0], &im_info)) {
-    FDERROR << "Failed to preprocess input image." << std::endl;
-    return false;
+  if (use_cuda_preprocessing_) {
+    if (!CudaPreprocess(&mat, &input_tensors[0], &im_info)) {
+      FDERROR << "Failed to preprocess input image." << std::endl;
+      return false;
+    }
+  } else {
+    if (!Preprocess(&mat, &input_tensors[0], &im_info)) {
+      FDERROR << "Failed to preprocess input image." << std::endl;
+      return false;
+    }
   }

input_tensors[0].name = InputInfoOfRuntime(0).name;
(Diff truncated; per the commit message, the remaining files apply the same CUDA preprocessing changes to YOLOv6, YOLOv7, and YOLOv7End2End.)
