Skip to content

Commit

Permalink
[Backend] refactor paddle custom ops -> fastdeploy::paddle_custom_ops (
Browse files Browse the repository at this point in the history
…#2101)

* [cmake] upgrade windows paddle inference -> 2.5.0

* [cmake] upgrade windows paddle inference -> 2.5.0

* fix paddle custom ops bug on windows

* [Backend] refactor paddle custom ops
  • Loading branch information
DefTruth committed Jul 13, 2023
1 parent 2542a75 commit 99c2b65
Show file tree
Hide file tree
Showing 12 changed files with 190 additions and 17 deletions.
28 changes: 26 additions & 2 deletions benchmark/cpp/benchmark.cc
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,12 @@ DEFINE_string(optimized_model_dir, "",
DEFINE_bool(collect_trt_shape_by_device, false,
"Optional, whether collect trt shape by device. "
"default false.");
DEFINE_double(custom_tensor_value, 1.0,
"Optional, set the value for fd tensor, "
"default 1.0");
DEFINE_bool(collect_trt_shape_by_custom_tensor_value, false,
"Optional, whether collect trt shape by custom tensor value. "
"default false.");

#if defined(ENABLE_BENCHMARK)
static std::vector<int64_t> GetInt64Shape(const std::vector<int>& shape) {
Expand Down Expand Up @@ -208,6 +214,23 @@ static void RuntimeProfiling(int argc, char* argv[]) {
for (int i = 0; i < input_shapes.size(); ++i) {
option.trt_option.SetShape(input_names[i], trt_shapes[i * 3],
trt_shapes[i * 3 + 1], trt_shapes[i * 3 + 2]);
// Set custom input data for collect trt shapes
if (FLAGS_collect_trt_shape_by_custom_tensor_value) {
int min_shape_num = std::accumulate(trt_shapes[i * 3].begin(),
trt_shapes[i * 3].end(), 1,
std::multiplies<int>());
int opt_shape_num = std::accumulate(trt_shapes[i * 3 + 1].begin(),
trt_shapes[i * 3 + 1].end(), 1,
std::multiplies<int>());
int max_shape_num = std::accumulate(trt_shapes[i * 3 + 2].begin(),
trt_shapes[i * 3 + 2].end(), 1,
std::multiplies<int>());
std::vector<float> min_input_data(min_shape_num, FLAGS_custom_tensor_value);
std::vector<float> opt_input_data(opt_shape_num, FLAGS_custom_tensor_value);
std::vector<float> max_input_data(max_shape_num, FLAGS_custom_tensor_value);
option.trt_option.SetInputData(input_names[i], min_input_data,
opt_input_data, max_input_data);
}
}
}

Expand All @@ -232,8 +255,9 @@ static void RuntimeProfiling(int argc, char* argv[]) {
// Feed inputs, all values set as 1.
std::vector<fastdeploy::FDTensor> inputs(runtime.NumInputs());
for (int i = 0; i < inputs.size(); ++i) {
fastdeploy::function::Full(1, GetInt64Shape(input_shapes[i]), &inputs[i],
input_dtypes[i]);
fastdeploy::function::Full(
FLAGS_custom_tensor_value, GetInt64Shape(input_shapes[i]),
&inputs[i], input_dtypes[i]);
inputs[i].name = input_names[i];
}

Expand Down
3 changes: 3 additions & 0 deletions benchmark/paddlex/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ add_executable(benchmark_structurev2_table ${PROJECT_SOURCE_DIR}/benchmark_struc
add_executable(benchmark_structurev2_layout ${PROJECT_SOURCE_DIR}/benchmark_structurev2_layout.cc)
add_executable(benchmark_ppshituv2_rec ${PROJECT_SOURCE_DIR}/benchmark_ppshituv2_rec.cc)
add_executable(benchmark_ppshituv2_det ${PROJECT_SOURCE_DIR}/benchmark_ppshituv2_det.cc)
add_executable(benchmark_pp3d_centerpoint ${PROJECT_SOURCE_DIR}/benchmark_pp3d_centerpoint.cc)

if(UNIX AND (NOT APPLE) AND (NOT ANDROID))
target_link_libraries(benchmark ${FASTDEPLOY_LIBS} gflags pthread)
Expand All @@ -33,6 +34,7 @@ if(UNIX AND (NOT APPLE) AND (NOT ANDROID))
target_link_libraries(benchmark_structurev2_layout ${FASTDEPLOY_LIBS} gflags pthread)
target_link_libraries(benchmark_ppshituv2_rec ${FASTDEPLOY_LIBS} gflags pthread)
target_link_libraries(benchmark_ppshituv2_det ${FASTDEPLOY_LIBS} gflags pthread)
target_link_libraries(benchmark_pp3d_centerpoint ${FASTDEPLOY_LIBS} gflags pthread)
else()
target_link_libraries(benchmark ${FASTDEPLOY_LIBS} gflags)
target_link_libraries(benchmark_ppcls ${FASTDEPLOY_LIBS} gflags)
Expand All @@ -44,6 +46,7 @@ else()
target_link_libraries(benchmark_structurev2_layout ${FASTDEPLOY_LIBS} gflags)
target_link_libraries(benchmark_ppshituv2_rec ${FASTDEPLOY_LIBS} gflags)
target_link_libraries(benchmark_ppshituv2_det ${FASTDEPLOY_LIBS} gflags)
target_link_libraries(benchmark_pp3d_centerpoint ${FASTDEPLOY_LIBS} gflags)
endif()
# only for Android ADB test
if(ANDROID)
Expand Down
100 changes: 100 additions & 0 deletions benchmark/paddlex/benchmark_pp3d_centerpoint.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "flags.h"
#include "macros.h"
#include "option.h"

namespace vision = fastdeploy::vision;
namespace benchmark = fastdeploy::benchmark;


static bool ReadTestPoint(const std::string &file_path,
std::vector<float> &data) {
int with_timelag = 0;
int64_t num_point_dim = 5;
std::ifstream file_in(file_path, std::ios::in | std::ios::binary);

if (!file_in) {
std::cout << "Failed to read file: " << file_path << std::endl;
return false;
}

std::streampos file_size;
file_in.seekg(0, std::ios::end);
file_size = file_in.tellg();
file_in.seekg(0, std::ios::beg);

data.resize(file_size / sizeof(float));

file_in.read(reinterpret_cast<char *>(data.data()), file_size);
file_in.close();

if (file_size / sizeof(float) % num_point_dim != 0) {
std::cout << "Loaded file size (" << file_size
<< ") is not evenly divisible by num_point_dim (" << num_point_dim
<< ")\n";
return false;
}
size_t num_points = file_size / sizeof(float) / num_point_dim;
if (!with_timelag && num_point_dim == 5 || num_point_dim > 5) {
for (int64_t i = 0; i < num_points; ++i) {
data[i * num_point_dim + 4] = 0.;
}
}
return true;
}

int main(int argc, char* argv[]) {
#if defined(ENABLE_BENCHMARK) && defined(ENABLE_VISION)
// Initialization
auto option = fastdeploy::RuntimeOption();
if (!CreateRuntimeOption(&option, argc, argv, true)) {
return -1;
}
std::string point_dir = FLAGS_image;
std::unordered_map<std::string, std::string> config_info;
benchmark::ResultManager::LoadBenchmarkConfig(FLAGS_config_path,
&config_info);
std::string model_name, params_name, config_name;
auto model_format = fastdeploy::ModelFormat::PADDLE;
if (!UpdateModelResourceName(&model_name, &params_name, &config_name,
&model_format, config_info, false)) {
return -1;
}
auto model_file = FLAGS_model + sep + model_name;
auto params_file = FLAGS_model + sep + params_name;
if (config_info["backend"] == "paddle_trt") {
option.paddle_infer_option.collect_trt_shape = true;
option.paddle_infer_option.collect_trt_shape_by_device = true;
}
if (config_info["backend"] == "paddle_trt" ||
config_info["backend"] == "trt") {
option.trt_option.SetShape("data", {34752, 5}, {34752, 5},
{34752, 5});
std::vector<float> min_input_data;
ReadTestPoint(point_dir, min_input_data);
// use custom data to perform collect shapes.
option.trt_option.SetInputData("data", min_input_data);
}
auto model_centerpoint = vision::perception::Centerpoint(
model_file, params_file, "", option, model_format);
vision::PerceptionResult res;
// Run profiling
BENCHMARK_MODEL(model_centerpoint, model_centerpoint.Predict(point_dir, &res))
// std::cout << res.Str() << std::endl;
#endif

return 0;
}
15 changes: 14 additions & 1 deletion benchmark/paddlex/get_models.sh
Original file line number Diff line number Diff line change
Expand Up @@ -60,17 +60,24 @@ download PP-OCRv4-server-det.tgz

# PP-ShiTuV2
download PP-ShiTuv2-rec.tgz
download PP-ShiTuv2-det.tgz

# PP-StructureV2
download PP-Structurev2-layout.tgz
download PP-Structurev2-SLANet.tgz
download PP-Structurev2-vi-layoutxlm.tgz

# Paddle3D
download CADNN_OCRNet-HRNetW18.tgz
download CenterPoint-Pillars-02Voxel.tgz
download PETRv1_v99.tgz
download PETRv2_v99.tgz

# Test resources
# PaddleClas
wget https://bj.bcebos.com/paddlehub/fastdeploy_paddlex_2_0/ppcls_cls_demo.JPEG

# PaddleDetection
# PaddleDetection & ppshitu-det
wget https://bj.bcebos.com/paddlehub/fastdeploy_paddlex_2_0/ppdet_det_img.jpg
wget https://bj.bcebos.com/paddlehub/fastdeploy_paddlex_2_0/ppdet_det_img_800x800.jpg

Expand All @@ -93,3 +100,9 @@ wget https://bj.bcebos.com/paddlehub/fastdeploy_paddlex_2_0/structurev2_layout_v
wget https://bj.bcebos.com/paddlehub/fastdeploy_paddlex_2_0/structurev2_vi_layoutxml_zh_val_0.jpg
wget https://bj.bcebos.com/paddlehub/fastdeploy_paddlex_2_0/table_structure_dict_ch.txt
wget https://bj.bcebos.com/paddlehub/fastdeploy_paddlex_2_0/layout_cdla_dict.txt

# Paddle3D
wget https://bj.bcebos.com/paddlehub/fastdeploy_paddlex_2_0/paddle3d_cadnn_kitti_000780.png
wget https://bj.bcebos.com/paddlehub/fastdeploy_paddlex_2_0/paddle3d_centerpoint_n008_LIDAR_TOP__1533151603547590.pcd.bin
wget https://bj.bcebos.com/paddlehub/fastdeploy_paddlex_2_0/paddle3d_petrv1_v99_nuscenes_sample_6.tgz && tar -zxvf paddle3d_petrv1_v99_nuscenes_sample_6.tgz
wget https://bj.bcebos.com/paddlehub/fastdeploy_paddlex_2_0/paddle3d_petrv2_v99_nuscenes_sample_12.tgz && tar -zxvf paddle3d_petrv2_v99_nuscenes_sample_12.tgz
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@
#include "paddle/extension.h"
#endif

namespace fastdeploy {
namespace paddle_custom_ops {

std::vector<paddle::Tensor> postprocess_gpu(
const std::vector<paddle::Tensor> &hm,
const std::vector<paddle::Tensor> &reg,
Expand Down Expand Up @@ -97,19 +100,22 @@ std::vector<paddle::DataType> PostProcessInferDtype(
return {reg_dtype[0], hm_dtype[0], paddle::DataType::INT64};
}

} // namespace fastdeploy
} // namespace paddle_custom_ops

PD_BUILD_OP(centerpoint_postprocess)
.Inputs({paddle::Vec("HM"), paddle::Vec("REG"), paddle::Vec("HEIGHT"),
paddle::Vec("DIM"), paddle::Vec("VEL"), paddle::Vec("ROT")})
.Outputs({"BBOXES", "SCORES", "LABELS"})
.SetKernelFn(PD_KERNEL(centerpoint_postprocess))
.SetKernelFn(PD_KERNEL(fastdeploy::paddle_custom_ops::centerpoint_postprocess))
.Attrs({"voxel_size: std::vector<float>",
"point_cloud_range: std::vector<float>",
"post_center_range: std::vector<float>",
"num_classes: std::vector<int>", "down_ratio: int",
"score_threshold: float", "nms_iou_threshold: float",
"nms_pre_max_size: int", "nms_post_max_size: int",
"with_velocity: bool"})
.SetInferShapeFn(PD_INFER_SHAPE(PostProcessInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(PostProcessInferDtype));
.SetInferShapeFn(PD_INFER_SHAPE(fastdeploy::paddle_custom_ops::PostProcessInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(fastdeploy::paddle_custom_ops::PostProcessInferDtype));

#endif // WITH_GPU
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@
#include "paddle/extension.h"
#endif

namespace fastdeploy {
namespace paddle_custom_ops {

#define CHECK_INPUT_CUDA(x) PD_CHECK(x.is_gpu(), #x " must be a GPU Tensor.")

#define CHECK_INPUT_BATCHSIZE(x) \
Expand Down Expand Up @@ -284,3 +287,6 @@ std::vector<paddle::Tensor> postprocess_gpu(
auto out_bboxes = paddle::experimental::concat(bboxes, 0);
return {out_bboxes, out_scores, out_labels};
}

} // namespace fastdeploy
} // namespace paddle_custom_ops
7 changes: 7 additions & 0 deletions fastdeploy/runtime/backends/paddle/ops/iou3d_nms_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ All Rights Reserved 2019-2020.
*/

#include <stdio.h>

namespace fastdeploy {
namespace paddle_custom_ops {

#define THREADS_PER_BLOCK 16
#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0))

Expand Down Expand Up @@ -315,3 +319,6 @@ void NmsLauncher(const cudaStream_t &stream, const float *bboxes,
num_bboxes, num_bboxes_for_nms, nms_overlap_thresh, decode_bboxes_dims,
bboxes, index, sorted_index, mask);
}

} // namespace fastdeploy
} // namespace paddle_custom_ops
21 changes: 12 additions & 9 deletions fastdeploy/runtime/backends/paddle/ops/voxelize_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@

#include <vector>

#if defined(WITH_GPU)

#if defined(PADDLEINFERENCE_API_COMPAT_2_4_x)
#include "paddle/include/experimental/ext_all.h"
#elif defined(PADDLEINFERENCE_API_COMPAT_2_5_x)
Expand All @@ -24,6 +22,9 @@
#include "paddle/extension.h"
#endif

namespace fastdeploy {
namespace paddle_custom_ops {

template <typename T, typename T_int>
bool hard_voxelize_cpu_kernel(
const T *points, const float point_cloud_range_x_min,
Expand Down Expand Up @@ -147,7 +148,8 @@ std::vector<paddle::Tensor> hard_voxelize_cpu(
return {voxels, coords, num_points_per_voxel, num_voxels};
}

#ifdef PADDLE_WITH_CUDA

#if defined(PADDLE_WITH_CUDA) && defined(WITH_GPU)
std::vector<paddle::Tensor> hard_voxelize_cuda(
const paddle::Tensor &points, const std::vector<float> &voxel_size,
const std::vector<float> &point_cloud_range, int max_num_points_in_voxel,
Expand All @@ -161,7 +163,7 @@ std::vector<paddle::Tensor> hard_voxelize(
if (points.is_cpu()) {
return hard_voxelize_cpu(points, voxel_size, point_cloud_range,
max_num_points_in_voxel, max_voxels);
#ifdef PADDLE_WITH_CUDA
#if defined(PADDLE_WITH_CUDA) && defined(WITH_GPU)
} else if (points.is_gpu() || points.is_gpu_pinned()) {
return hard_voxelize_cuda(points, voxel_size, point_cloud_range,
max_num_points_in_voxel, max_voxels);
Expand All @@ -188,14 +190,15 @@ std::vector<paddle::DataType> HardInferDtype(paddle::DataType points_dtype) {
paddle::DataType::INT32};
}

} // namespace fastdeploy
} // namespace paddle_custom_ops

PD_BUILD_OP(hard_voxelize)
.Inputs({"POINTS"})
.Outputs({"VOXELS", "COORS", "NUM_POINTS_PER_VOXEL", "num_voxels"})
.SetKernelFn(PD_KERNEL(hard_voxelize))
.SetKernelFn(PD_KERNEL(fastdeploy::paddle_custom_ops::hard_voxelize))
.Attrs({"voxel_size: std::vector<float>",
"point_cloud_range: std::vector<float>",
"max_num_points_in_voxel: int", "max_voxels: int"})
.SetInferShapeFn(PD_INFER_SHAPE(HardInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(HardInferDtype));

#endif // WITH_GPU
.SetInferShapeFn(PD_INFER_SHAPE(fastdeploy::paddle_custom_ops::HardInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(fastdeploy::paddle_custom_ops::HardInferDtype));
6 changes: 6 additions & 0 deletions fastdeploy/runtime/backends/paddle/ops/voxelize_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@
#include "paddle/extension.h"
#endif

namespace fastdeploy {
namespace paddle_custom_ops {

#define CHECK_INPUT_CUDA(x) \
PD_CHECK(x.is_gpu() || x.is_gpu_pinned(), #x " must be a GPU Tensor.")

Expand Down Expand Up @@ -349,3 +352,6 @@ std::vector<paddle::Tensor> hard_voxelize_cuda(

return {voxels, coords, num_points_per_voxel, num_voxels};
}

} // namespace fastdeploy
} // namespace paddle_custom_ops
2 changes: 2 additions & 0 deletions fastdeploy/runtime/backends/paddle/option.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,8 @@ struct PaddleBackendOption {
int gpu_mem_init_size = 100;
/// The option to enable fixed size optimization for transformer model
bool enable_fixed_size_opt = false;
/// min_subgraph_size for paddle-trt
int trt_min_subgraph_size = 3;

/// Disable type of operators run on TensorRT
void DisableTrtOps(const std::vector<std::string>& ops) {
Expand Down
1 change: 1 addition & 0 deletions fastdeploy/runtime/backends/paddle/option_pybind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ void BindPaddleOption(pybind11::module& m) {
&PaddleBackendOption::is_quantize_model)
.def_readwrite("inference_precision", &PaddleBackendOption::inference_precision)
.def_readwrite("enable_inference_cutlass",&PaddleBackendOption::enable_inference_cutlass)
.def_readwrite("trt_min_subgraph_size",&PaddleBackendOption::trt_min_subgraph_size)
.def("disable_trt_ops", &PaddleBackendOption::DisableTrtOps)
.def("delete_pass", &PaddleBackendOption::DeletePass)
.def("set_ipu_config", &PaddleBackendOption::SetIpuConfig);
Expand Down
6 changes: 4 additions & 2 deletions fastdeploy/runtime/backends/paddle/paddle_backend.cc
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,8 @@ void PaddleBackend::BuildOption(const PaddleBackendOption& option) {
config_.SetOptimCacheDir(opt_cache_dir);
}
config_.EnableTensorRtEngine(option.trt_option.max_workspace_size,
option.trt_option.max_batch_size, 3,
option.trt_option.max_batch_size,
option.trt_min_subgraph_size,
precision, use_static);
SetTRTDynamicShapeToConfig(option);
if (option_.enable_fixed_size_opt) {
Expand Down Expand Up @@ -225,7 +226,8 @@ bool PaddleBackend::InitFromPaddle(const std::string& model,
use_static = true;
}
config_.EnableTensorRtEngine(option.trt_option.max_workspace_size,
option.trt_option.max_batch_size, 3,
option.trt_option.max_batch_size,
option.trt_min_subgraph_size,
paddle_infer::PrecisionType::kInt8,
use_static, false);
SetTRTDynamicShapeToConfig(option);
Expand Down

0 comments on commit 99c2b65

Please sign in to comment.