Skip to content

Commit

Permalink
[Sync][Internal] sync some internal paddle3d codes (#2108)
Browse files Browse the repository at this point in the history
  • Loading branch information
DefTruth committed Jul 13, 2023
1 parent 77ee48f commit 681ccc4
Show file tree
Hide file tree
Showing 30 changed files with 2,517 additions and 45 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,7 @@ endif()
if(ENABLE_ENCRYPTION)
add_definitions(-DENABLE_ENCRYPTION)
list(APPEND ALL_DEPLOY_SRCS ${DEPLOY_ENCRYPTION_SRCS})
include(${PROJECT_SOURCE_DIR}/cmake/gflags.cmake)
# include(${PROJECT_SOURCE_DIR}/cmake/gflags.cmake)
include(${PROJECT_SOURCE_DIR}/cmake/openssl.cmake)
list(APPEND DEPEND_LIBS ${OPENSSL_LIBRARIES})
endif()
Expand Down
3 changes: 3 additions & 0 deletions benchmark/paddlex/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ add_executable(benchmark_structurev2_table ${PROJECT_SOURCE_DIR}/benchmark_struc
add_executable(benchmark_structurev2_layout ${PROJECT_SOURCE_DIR}/benchmark_structurev2_layout.cc)
add_executable(benchmark_ppshituv2_rec ${PROJECT_SOURCE_DIR}/benchmark_ppshituv2_rec.cc)
add_executable(benchmark_ppshituv2_det ${PROJECT_SOURCE_DIR}/benchmark_ppshituv2_det.cc)
add_executable(benchmark_pp3d_cadnn ${PROJECT_SOURCE_DIR}/benchmark_pp3d_cadnn.cc)
add_executable(benchmark_pp3d_centerpoint ${PROJECT_SOURCE_DIR}/benchmark_pp3d_centerpoint.cc)

if(UNIX AND (NOT APPLE) AND (NOT ANDROID))
Expand All @@ -34,6 +35,7 @@ if(UNIX AND (NOT APPLE) AND (NOT ANDROID))
target_link_libraries(benchmark_structurev2_layout ${FASTDEPLOY_LIBS} gflags pthread)
target_link_libraries(benchmark_ppshituv2_rec ${FASTDEPLOY_LIBS} gflags pthread)
target_link_libraries(benchmark_ppshituv2_det ${FASTDEPLOY_LIBS} gflags pthread)
target_link_libraries(benchmark_pp3d_cadnn ${FASTDEPLOY_LIBS} gflags pthread)
target_link_libraries(benchmark_pp3d_centerpoint ${FASTDEPLOY_LIBS} gflags pthread)
else()
target_link_libraries(benchmark ${FASTDEPLOY_LIBS} gflags)
Expand All @@ -46,6 +48,7 @@ else()
target_link_libraries(benchmark_structurev2_layout ${FASTDEPLOY_LIBS} gflags)
target_link_libraries(benchmark_ppshituv2_rec ${FASTDEPLOY_LIBS} gflags)
target_link_libraries(benchmark_ppshituv2_det ${FASTDEPLOY_LIBS} gflags)
target_link_libraries(benchmark_pp3d_cadnn ${FASTDEPLOY_LIBS} gflags)
target_link_libraries(benchmark_pp3d_centerpoint ${FASTDEPLOY_LIBS} gflags)
endif()
# only for Android ADB test
Expand Down
8 changes: 8 additions & 0 deletions benchmark/paddlex/benchmark_gpu.sh
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,17 @@ fi

# PP-ShiTuV2
./benchmark_ppshituv2_rec --model PP-ShiTuv2-rec --image ppshituv2_wangzai.png --config_path $CONFIG_PATH
./benchmark_ppshituv2_det --model PP-ShiTuv2-det --image ppdet_det_img.jpg --config_path $CONFIG_PATH

# PP-StructureV2
./benchmark_structurev2_layout --model PP-Structurev2-layout --image structurev2_layout_val_0002.jpg --config_path $CONFIG_PATH
./benchmark_structurev2_table --model PP-Structurev2-SLANet --image structurev2_table.jpg --table_char_dict_path table_structure_dict_ch.txt --config_path $CONFIG_PATH
./benchmark --model PP-Structurev2-vi-layoutxlm --shapes "1,512:1,512,4:1,512:1,512" --trt_shapes "1,512:1,512:1,512:1,512,4:1,512,4:1,512,4:1,512:1,512:1,512:1,512:1,512:1,512" --names "x_0:x_1:x_2:x_3" --dtypes "INT64:INT64:INT64:INT64" --disable_mkldnn --custom_tensor_value 0.2 --config_path $CONFIG_PATH

# Paddle3D
./benchmark --model PETRv1_v99 --shapes "1,6,3,320,800:1,6,4,4" --names "images:img2lidars" --dtypes "FP32:FP32" --disable_mkldnn --config_path $CONFIG_PATH
./benchmark --model PETRv2_v99 --shapes "1,12,3,320,800:1,12,4,4:1,12" --names "images:img2lidars:timestamps" --dtypes "FP32:FP32:FP32" --disable_mkldnn --config_path $CONFIG_PATH
./benchmark_pp3d_centerpoint --model CenterPoint-Pillars-02Voxel --image paddle3d_centerpoint_n008_LIDAR_TOP__1533151603547590.pcd.bin --config_path $CONFIG_PATH
./benchmark_pp3d_cadnn --model CADNN_OCRNet-HRNetW18 --image paddle3d_cadnn_kitti_000780.png --config_path $CONFIG_PATH

set +x
8 changes: 8 additions & 0 deletions benchmark/paddlex/benchmark_gpu_trt.sh
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,17 @@ fi

# PP-ShiTuV2
./benchmark_ppshituv2_rec --model PP-ShiTuv2-rec --image ppshituv2_wangzai.png --config_path $CONFIG_PATH
./benchmark_ppshituv2_det --model PP-ShiTuv2-det --image ppdet_det_img.jpg --config_path $CONFIG_PATH

# PP-StructureV2
./benchmark_structurev2_layout --model PP-Structurev2-layout --image structurev2_layout_val_0002.jpg --config_path $CONFIG_PATH
./benchmark_structurev2_table --model PP-Structurev2-SLANet --image structurev2_table.jpg --table_char_dict_path table_structure_dict_ch.txt --config_path $CONFIG_PATH
./benchmark --model PP-Structurev2-vi-layoutxlm --shapes "1,512:1,512,4:1,512:1,512" --trt_shapes "1,512:1,512:1,512:1,512,4:1,512,4:1,512,4:1,512:1,512:1,512:1,512:1,512:1,512" --names "x_0:x_1:x_2:x_3" --dtypes "INT64:INT64:INT64:INT64" --disable_mkldnn --custom_tensor_value 0.2 --collect_trt_shape_by_custom_tensor_value --collect_trt_shape_by_device --config_path $CONFIG_PATH

# Paddle3D
./benchmark --model PETRv1_v99 --shapes "1,6,3,320,800:1,6,4,4" --trt_shapes "1,6,3,320,800:1,6,3,320,800:1,6,3,320,800:1,6,4,4:1,6,4,4:1,6,4,4" --names "images:img2lidars" --dtypes "FP32:FP32" --disable_mkldnn --config_path $CONFIG_PATH
./benchmark --model PETRv2_v99 --shapes "1,12,3,320,800:1,12,4,4:1,12" --trt_shapes "1,12,3,320,800:1,12,3,320,800:1,12,3,320,800:1,12,4,4:1,12,4,4:1,12,4,4:1,12:1,12:1,12" --names "images:img2lidars:timestamps" --dtypes "FP32:FP32:FP32" --disable_mkldnn --config_path $CONFIG_PATH
./benchmark_pp3d_centerpoint --model CenterPoint-Pillars-02Voxel --image paddle3d_centerpoint_n008_LIDAR_TOP__1533151603547590.pcd.bin --config_path $CONFIG_PATH
./benchmark_pp3d_cadnn --model CADNN_OCRNet-HRNetW18 --image paddle3d_cadnn_kitti_000780.png --config_path $CONFIG_PATH

set +x
81 changes: 81 additions & 0 deletions benchmark/paddlex/benchmark_pp3d_cadnn.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "flags.h"
#include "macros.h"
#include "option.h"

namespace vision = fastdeploy::vision;
namespace benchmark = fastdeploy::benchmark;

int main(int argc, char* argv[]) {
#if defined(ENABLE_BENCHMARK) && defined(ENABLE_VISION)
// Initialization
auto option = fastdeploy::RuntimeOption();
if (!CreateRuntimeOption(&option, argc, argv, true)) {
return -1;
}
auto im = cv::imread(FLAGS_image);
std::unordered_map<std::string, std::string> config_info;
benchmark::ResultManager::LoadBenchmarkConfig(FLAGS_config_path,
&config_info);
std::string model_name, params_name, config_name;
auto model_format = fastdeploy::ModelFormat::PADDLE;
if (!UpdateModelResourceName(&model_name, &params_name, &config_name,
&model_format, config_info, false)) {
return -1;
}
auto model_file = FLAGS_model + sep + model_name;
auto params_file = FLAGS_model + sep + params_name;
std::vector<float> cam_data{7.183351e+02, 0.000000e+00, 6.003891e+02,
4.450382e+01, 0.000000e+00, 7.183351e+02,
1.815122e+02, -5.951107e-01, 0.000000e+00,
0.000000e+00, 1.000000e+00, 2.616315e-03};
std::vector<float> lidar_data = {
0.0048523, -0.9999298, -0.01081266, -0.00711321,
-0.00302069, 0.01079808, -0.99993706, -0.06176636,
0.99998367, 0.00488465, -0.00296808, -0.26739058,
0., 0., 0., 1.};
if (config_info["backend"] == "paddle_trt") {
option.paddle_infer_option.collect_trt_shape = true;
option.paddle_infer_option.collect_trt_shape_by_device = true;
option.paddle_infer_option.trt_min_subgraph_size = 12;
option.paddle_infer_option.DisableTrtOps({"squeeze2"});
option.trt_option.max_batch_size = 1;
}
if (config_info["backend"] == "paddle_trt" ||
config_info["backend"] == "trt") {
// use custom data to perform collect shapes.
option.trt_option.SetShape("images", {1, 3, 375, 1242},
{1, 3, 375, 1242}, {1, 3, 375, 1242});
option.trt_option.SetShape("trans_lidar_to_cam", {1, 4, 4},
{1, 4, 4}, {1, 4, 4});
option.trt_option.SetShape("trans_cam_to_img", {1, 3, 4},
{1, 3, 4}, {1, 3, 4});
std::vector<float> image_data;
image_data.assign(im.data, im.data + 1*3*375*1242);
option.trt_option.SetInputData("trans_lidar_to_cam", lidar_data);
option.trt_option.SetInputData("trans_cam_to_img", cam_data);
option.trt_option.SetInputData("images", image_data);
}
auto model_cadnn = vision::perception::Caddn(
model_file, params_file, "", option, model_format);
vision::PerceptionResult res;
// Run profiling
BENCHMARK_MODEL(model_cadnn, model_cadnn.Predict(im, cam_data, lidar_data, &res))
std::cout << res.Str() << std::endl;
#endif

return 0;
}
2 changes: 1 addition & 1 deletion benchmark/paddlex/benchmark_pp3d_centerpoint.cc
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ int main(int argc, char* argv[]) {
vision::PerceptionResult res;
// Run profiling
BENCHMARK_MODEL(model_centerpoint, model_centerpoint.Predict(point_dir, &res))
// std::cout << res.Str() << std::endl;
std::cout << res.Str() << std::endl;
#endif

return 0;
Expand Down
8 changes: 8 additions & 0 deletions benchmark/paddlex/benchmark_x86.sh
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,17 @@ fi

# PP-ShiTuV2
./benchmark_ppshituv2_rec --model PP-ShiTuv2-rec --image ppshituv2_wangzai.png --config_path $CONFIG_PATH
./benchmark_ppshituv2_det --model PP-ShiTuv2-det --image ppdet_det_img.jpg --config_path $CONFIG_PATH

# PP-StructureV2
./benchmark_structurev2_layout --model PP-Structurev2-layout --image structurev2_layout_val_0002.jpg --config_path $CONFIG_PATH
./benchmark_structurev2_table --model PP-Structurev2-SLANet --image structurev2_table.jpg --table_char_dict_path table_structure_dict_ch.txt --config_path $CONFIG_PATH
./benchmark --model PP-Structurev2-vi-layoutxlm --shapes "1,512:1,512,4:1,512:1,512" --names "x_0:x_1:x_2:x_3" --dtypes "INT64:INT64:INT64:INT64" --disable_mkldnn --custom_tensor_value 0.2 --config_path $CONFIG_PATH

# Paddle3D
./benchmark --model PETRv1_v99 --config_path $CONFIG_PATH --shapes "1,6,3,320,800:1,6,4,4" --names "images:img2lidars" --dtypes "FP32:FP32" --disable_mkldnn --warmup 5 --repeat 20
./benchmark --model PETRv2_v99 --config_path $CONFIG_PATH --shapes "1,12,3,320,800:1,12,4,4:1,12" --names "images:img2lidars:timestamps" --dtypes "FP32:FP32:FP32" --disable_mkldnn --warmup 5 --repeat 20
./benchmark_pp3d_centerpoint --model CenterPoint-Pillars-02Voxel --image paddle3d_centerpoint_n008_LIDAR_TOP__1533151603547590.pcd.bin --config_path $CONFIG_PATH
./benchmark_pp3d_cadnn --model CADNN_OCRNet-HRNetW18 --image paddle3d_cadnn_kitti_000780.png --config_path $CONFIG_PATH

set +x
Original file line number Diff line number Diff line change
Expand Up @@ -118,4 +118,5 @@ PD_BUILD_OP(centerpoint_postprocess)
.SetInferShapeFn(PD_INFER_SHAPE(fastdeploy::paddle_custom_ops::PostProcessInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(fastdeploy::paddle_custom_ops::PostProcessInferDtype));

#endif // WITH_GPU
#endif // WITH_GPU

Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ std::vector<paddle::Tensor> postprocess_gpu(

// nms
// in NmsLauncher, rot = - theta - pi / 2
const int col_blocks = DIVUP(num_bboxes_for_nms, THREADS_PER_BLOCK_NMS);
int col_blocks = DIVUP(num_bboxes_for_nms, THREADS_PER_BLOCK_NMS);
auto nms_mask = paddle::empty({num_bboxes_for_nms * col_blocks},
paddle::DataType::INT64, paddle::GPUPlace());
int64_t *nms_mask_data = nms_mask.data<int64_t>();
Expand Down Expand Up @@ -291,4 +291,4 @@ std::vector<paddle::Tensor> postprocess_gpu(
}

} // namespace fastdeploy
} // namespace paddle_custom_ops
} // namespace paddle_custom_ops
94 changes: 94 additions & 0 deletions fastdeploy/runtime/backends/paddle/ops/grid_sample_3d.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#if defined(WITH_GPU)

#include "grid_sample_3d.h"

#include <vector>

#if defined(PADDLEINFERENCE_API_COMPAT_2_4_x)
#include "paddle/include/experimental/ext_all.h"
#elif defined(PADDLEINFERENCE_API_COMPAT_2_5_x)
#include "paddle/include/paddle/extension.h"
#else
#include "paddle/extension.h"
#endif

namespace fastdeploy {
namespace paddle_custom_ops {

std::vector<paddle::Tensor> GridSample3DCUDAForward(
const paddle::Tensor& x, const paddle::Tensor& grid,
const std::string& mode, const std::string& padding_mode,
bool align_corners);

std::vector<paddle::Tensor> GridSample3DForward(const paddle::Tensor& x,
const paddle::Tensor& grid,
const std::string& mode,
const std::string& padding_mode,
bool align_corners) {
return GridSample3DCUDAForward(x, grid, mode, padding_mode, align_corners);
}

std::vector<paddle::Tensor> GridSample3DCUDABackward(
const paddle::Tensor& x, const paddle::Tensor& grid,
const paddle::Tensor& grad_out, const std::string& mode,
const std::string& padding_mode, bool align_corners);

std::vector<paddle::Tensor> GridSample3DBackward(
const paddle::Tensor& x, const paddle::Tensor& grid,
const paddle::Tensor& grad_out, const std::string& mode,
const std::string& padding_mode, bool align_corners) {
return GridSample3DCUDABackward(x, grid, grad_out, mode, padding_mode,
align_corners);
}

std::vector<std::vector<int64_t>> GridSample3DInferShape(
std::vector<int64_t> x_shape, std::vector<int64_t> grid_shape) {
return {
{x_shape[0], x_shape[1], grid_shape[1], grid_shape[2], grid_shape[3]}};
}

std::vector<std::vector<int64_t>> GridSample3DInferBackShape(
std::vector<int64_t> x_shape, std::vector<int64_t> grid_shape) {
return {x_shape};
}

std::vector<paddle::DataType> GridSample3DInferDtype(
paddle::DataType x_dtype, paddle::DataType grid_dtype) {
return {x_dtype};
}

} // namespace fastdeploy
} // namespace paddle_custom_ops

PD_BUILD_OP(grid_sample_3d)
.Inputs({"x", "grid"})
.Attrs({"mode: std::string", "padding_mode: std::string",
"align_corners: bool"})
.Outputs({"out"})
.SetKernelFn(PD_KERNEL(fastdeploy::paddle_custom_ops::GridSample3DForward))
.SetInferShapeFn(PD_INFER_SHAPE(fastdeploy::paddle_custom_ops::GridSample3DInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(fastdeploy::paddle_custom_ops::GridSample3DInferDtype));

PD_BUILD_GRAD_OP(grid_sample_3d)
.Inputs({"x", "grid", paddle::Grad("out")})
.Attrs({"mode: std::string", "padding_mode: std::string",
"align_corners: bool"})
.Outputs({paddle::Grad("x")})
.SetKernelFn(PD_KERNEL(fastdeploy::paddle_custom_ops::GridSample3DBackward))
.SetInferShapeFn(PD_INFER_SHAPE(fastdeploy::paddle_custom_ops::GridSample3DInferBackShape));

#endif
Loading

0 comments on commit 681ccc4

Please sign in to comment.