Skip to content

Commit

Permalink
[Sync][Internal] sync some internal features of paddle3d inference (#…
Browse files Browse the repository at this point in the history
…2118)

* [Sync][Internal] sync some internal codes

* [Sync][Internal] sync some internal features of paddle3d inference

* [Sync][Internal] sync some internal features of paddle3d inference
  • Loading branch information
DefTruth committed Jul 17, 2023
1 parent f413e02 commit ade27d2
Show file tree
Hide file tree
Showing 10 changed files with 107 additions and 157 deletions.
6 changes: 3 additions & 3 deletions benchmark/paddlex/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ add_executable(benchmark_structurev2_table ${PROJECT_SOURCE_DIR}/benchmark_struc
add_executable(benchmark_structurev2_layout ${PROJECT_SOURCE_DIR}/benchmark_structurev2_layout.cc)
add_executable(benchmark_ppshituv2_rec ${PROJECT_SOURCE_DIR}/benchmark_ppshituv2_rec.cc)
add_executable(benchmark_ppshituv2_det ${PROJECT_SOURCE_DIR}/benchmark_ppshituv2_det.cc)
add_executable(benchmark_pp3d_cadnn ${PROJECT_SOURCE_DIR}/benchmark_pp3d_cadnn.cc)
add_executable(benchmark_pp3d_caddn ${PROJECT_SOURCE_DIR}/benchmark_pp3d_caddn.cc)
add_executable(benchmark_pp3d_centerpoint ${PROJECT_SOURCE_DIR}/benchmark_pp3d_centerpoint.cc)

if(UNIX AND (NOT APPLE) AND (NOT ANDROID))
Expand All @@ -35,7 +35,7 @@ if(UNIX AND (NOT APPLE) AND (NOT ANDROID))
target_link_libraries(benchmark_structurev2_layout ${FASTDEPLOY_LIBS} gflags pthread)
target_link_libraries(benchmark_ppshituv2_rec ${FASTDEPLOY_LIBS} gflags pthread)
target_link_libraries(benchmark_ppshituv2_det ${FASTDEPLOY_LIBS} gflags pthread)
target_link_libraries(benchmark_pp3d_cadnn ${FASTDEPLOY_LIBS} gflags pthread)
target_link_libraries(benchmark_pp3d_caddn ${FASTDEPLOY_LIBS} gflags pthread)
target_link_libraries(benchmark_pp3d_centerpoint ${FASTDEPLOY_LIBS} gflags pthread)
else()
target_link_libraries(benchmark ${FASTDEPLOY_LIBS} gflags)
Expand All @@ -48,7 +48,7 @@ else()
target_link_libraries(benchmark_structurev2_layout ${FASTDEPLOY_LIBS} gflags)
target_link_libraries(benchmark_ppshituv2_rec ${FASTDEPLOY_LIBS} gflags)
target_link_libraries(benchmark_ppshituv2_det ${FASTDEPLOY_LIBS} gflags)
target_link_libraries(benchmark_pp3d_cadnn ${FASTDEPLOY_LIBS} gflags)
target_link_libraries(benchmark_pp3d_caddn ${FASTDEPLOY_LIBS} gflags)
target_link_libraries(benchmark_pp3d_centerpoint ${FASTDEPLOY_LIBS} gflags)
endif()
# only for Android ADB test
Expand Down
3 changes: 2 additions & 1 deletion benchmark/paddlex/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -131,10 +131,11 @@ tar -zxvf MobileNetV3_small_x1_0.tgz
## 4. 各个硬件上的一键运行脚本

在准备好相关的环境配置和SDK后,可以使用本目录提供的脚本一键运行后的benchmark数据。
- 获取模型和资源文件
- 获取模型和资源文件
```bash
./get_models.sh
```

- 运行benchmark脚本
```bash
# x86 CPU Paddle backend fp32
Expand Down
10 changes: 5 additions & 5 deletions benchmark/paddlex/benchmark_gpu.sh
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ fi

# PaddleSeg
./benchmark_ppseg --model OCRNet_HRNetW48 --image ppseg_cityscapes_demo_512x512.png --config_path $CONFIG_PATH
./benchmark_ppseg --model PP-LiteSeg-STDC1 --image ppseg_cityscapes_demo_512x512.png --config_path $CONFIG_PATH
./benchmark_ppseg --model PP-LiteSeg-STDC1 --image ppseg_cityscapes_demo_512x512.png --config_path $CONFIG_PATH
./benchmark_ppseg --model SegFormer-B0 --image ppseg_cityscapes_demo_512x512.png --config_path $CONFIG_PATH
./benchmark_ppseg --model PP-MobileSeg-Base --image ppseg_ade_val_512x512.png --config_path $CONFIG_PATH

Expand All @@ -49,9 +49,9 @@ fi
./benchmark --model PP-Structurev2-vi-layoutxlm --shapes "1,512:1,512,4:1,512:1,512" --trt_shapes "1,512:1,512:1,512:1,512,4:1,512,4:1,512,4:1,512:1,512:1,512:1,512:1,512:1,512" --names "x_0:x_1:x_2:x_3" --dtypes "INT64:INT64:INT64:INT64" --disable_mkldnn --custom_tensor_value 0.2 --config_path $CONFIG_PATH

# Paddle3D
./benchmark --model PETRv1_v99 --shapes "1,6,3,320,800:1,6,4,4" --names "images:img2lidars" --dtypes "FP32:FP32" --disable_mkldnn --config_path $CONFIG_PATH
./benchmark --model PETRv2_v99 --shapes "1,12,3,320,800:1,12,4,4:1,12" --names "images:img2lidars:timestamps" --dtypes "FP32:FP32:FP32" --disable_mkldnn --config_path $CONFIG_PATH
./benchmark_pp3d_centerpoint --model CenterPoint-Pillars-02Voxel --image paddle3d_centerpoint_n008_LIDAR_TOP__1533151603547590.pcd.bin --config_path $CONFIG_PATH
./benchmark_pp3d_cadnn --model CADNN_OCRNet-HRNetW18 --image paddle3d_cadnn_kitti_000780.png --config_path $CONFIG_PATH
./benchmark --model PETRv1_v99 --shapes "1,6,3,320,800:1,6,4,4" --names "images:img2lidars" --dtypes "FP32:FP32" --disable_mkldnn --config_path $CONFIG_PATH
./benchmark --model PETRv2_v99 --shapes "1,12,3,320,800:1,12,4,4:1,12" --names "images:img2lidars:timestamps" --dtypes "FP32:FP32:FP32" --disable_mkldnn --config_path $CONFIG_PATH
./benchmark_pp3d_centerpoint --model CenterPoint-Pillars-02Voxel --image paddle3d_centerpoint_n008_LIDAR_TOP__1533151603547590.pcd.bin --config_path $CONFIG_PATH
./benchmark_pp3d_caddn --model CADDN_OCRNet-HRNetW18 --image paddle3d_caddn_kitti_000780.png --config_path $CONFIG_PATH

set +x
8 changes: 4 additions & 4 deletions benchmark/paddlex/benchmark_gpu_trt.sh
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,9 @@ fi
./benchmark --model PP-Structurev2-vi-layoutxlm --shapes "1,512:1,512,4:1,512:1,512" --trt_shapes "1,512:1,512:1,512:1,512,4:1,512,4:1,512,4:1,512:1,512:1,512:1,512:1,512:1,512" --names "x_0:x_1:x_2:x_3" --dtypes "INT64:INT64:INT64:INT64" --disable_mkldnn --custom_tensor_value 0.2 --collect_trt_shape_by_custom_tensor_value --collect_trt_shape_by_device --config_path $CONFIG_PATH

# Paddle3D
./benchmark --model PETRv1_v99 --shapes "1,6,3,320,800:1,6,4,4" --trt_shapes "1,6,3,320,800:1,6,3,320,800:1,6,3,320,800:1,6,4,4:1,6,4,4:1,6,4,4" --names "images:img2lidars" --dtypes "FP32:FP32" --disable_mkldnn --config_path $CONFIG_PATH
./benchmark --model PETRv2_v99 --shapes "1,12,3,320,800:1,12,4,4:1,12" --trt_shapes "1,12,3,320,800:1,12,3,320,800:1,12,3,320,800:1,12,4,4:1,12,4,4:1,12,4,4:1,12:1,12:1,12" --names "images:img2lidars:timestamps" --dtypes "FP32:FP32:FP32" --disable_mkldnn --config_path $CONFIG_PATH
./benchmark_pp3d_centerpoint --model CenterPoint-Pillars-02Voxel --image paddle3d_centerpoint_n008_LIDAR_TOP__1533151603547590.pcd.bin --config_path $CONFIG_PATH
./benchmark_pp3d_cadnn --model CADNN_OCRNet-HRNetW18 --image paddle3d_cadnn_kitti_000780.png --config_path $CONFIG_PATH
./benchmark --model PETRv1_v99 --shapes "1,6,3,320,800:1,6,4,4" --trt_shapes "1,6,3,320,800:1,6,3,320,800:1,6,3,320,800:1,6,4,4:1,6,4,4:1,6,4,4" --names "images:img2lidars" --dtypes "FP32:FP32" --disable_mkldnn --config_path $CONFIG_PATH
./benchmark --model PETRv2_v99 --shapes "1,12,3,320,800:1,12,4,4:1,12" --trt_shapes "1,12,3,320,800:1,12,3,320,800:1,12,3,320,800:1,12,4,4:1,12,4,4:1,12,4,4:1,12:1,12:1,12" --names "images:img2lidars:timestamps" --dtypes "FP32:FP32:FP32" --disable_mkldnn --config_path $CONFIG_PATH
./benchmark_pp3d_centerpoint --model CenterPoint-Pillars-02Voxel --image paddle3d_centerpoint_n008_LIDAR_TOP__1533151603547590.pcd.bin --config_path $CONFIG_PATH
./benchmark_pp3d_caddn --model CADDN_OCRNet-HRNetW18 --image paddle3d_caddn_kitti_000780.png --config_path $CONFIG_PATH

set +x
Original file line number Diff line number Diff line change
Expand Up @@ -56,24 +56,25 @@ int main(int argc, char* argv[]) {
}
if (config_info["backend"] == "paddle_trt" ||
config_info["backend"] == "trt") {
// use custom data to perform collect shapes.
option.trt_option.SetShape("images", {1, 3, 375, 1242},
{1, 3, 375, 1242}, {1, 3, 375, 1242});
option.trt_option.SetShape("trans_lidar_to_cam", {1, 4, 4},
{1, 4, 4}, {1, 4, 4});
option.trt_option.SetShape("trans_cam_to_img", {1, 3, 4},
{1, 3, 4}, {1, 3, 4});
// use custom data to perform collect shapes.
option.trt_option.SetShape("images", {1, 3, 375, 1242}, {1, 3, 375, 1242},
{1, 3, 375, 1242});
option.trt_option.SetShape("trans_lidar_to_cam", {1, 4, 4}, {1, 4, 4},
{1, 4, 4});
option.trt_option.SetShape("trans_cam_to_img", {1, 3, 4}, {1, 3, 4},
{1, 3, 4});
std::vector<float> image_data;
image_data.assign(im.data, im.data + 1*3*375*1242);
option.trt_option.SetInputData("trans_lidar_to_cam", lidar_data);
option.trt_option.SetInputData("trans_cam_to_img", cam_data);
option.trt_option.SetInputData("images", image_data);
image_data.assign(im.data, im.data + 1 * 3 * 375 * 1242);
option.trt_option.SetInputData("trans_lidar_to_cam", lidar_data);
option.trt_option.SetInputData("trans_cam_to_img", cam_data);
option.trt_option.SetInputData("images", image_data);
}
auto model_cadnn = vision::perception::Caddn(
model_file, params_file, "", option, model_format);
auto model_cadnn = vision::perception::Caddn(model_file, params_file, "",
option, model_format);
vision::PerceptionResult res;
// Run profiling
BENCHMARK_MODEL(model_cadnn, model_cadnn.Predict(im, cam_data, lidar_data, &res))
BENCHMARK_MODEL(model_cadnn,
model_cadnn.Predict(im, cam_data, lidar_data, &res))
std::cout << res.Str() << std::endl;
#endif

Expand Down
8 changes: 4 additions & 4 deletions benchmark/paddlex/benchmark_x86.sh
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ fi

# PaddleSeg
./benchmark_ppseg --model OCRNet_HRNetW48 --image ppseg_cityscapes_demo_512x512.png --config_path $CONFIG_PATH
./benchmark_ppseg --model PP-LiteSeg-STDC1 --image ppseg_cityscapes_demo_512x512.png --config_path $CONFIG_PATH
./benchmark_ppseg --model PP-LiteSeg-STDC1 --image ppseg_cityscapes_demo_512x512.png --config_path $CONFIG_PATH
./benchmark_ppseg --model SegFormer-B0 --image ppseg_cityscapes_demo_512x512.png --config_path $CONFIG_PATH
./benchmark_ppseg --model PP-MobileSeg-Base --image ppseg_ade_val_512x512.png --config_path $CONFIG_PATH

Expand All @@ -51,7 +51,7 @@ fi
# Paddle3D
./benchmark --model PETRv1_v99 --config_path $CONFIG_PATH --shapes "1,6,3,320,800:1,6,4,4" --names "images:img2lidars" --dtypes "FP32:FP32" --disable_mkldnn --warmup 5 --repeat 20
./benchmark --model PETRv2_v99 --config_path $CONFIG_PATH --shapes "1,12,3,320,800:1,12,4,4:1,12" --names "images:img2lidars:timestamps" --dtypes "FP32:FP32:FP32" --disable_mkldnn --warmup 5 --repeat 20
./benchmark_pp3d_centerpoint --model CenterPoint-Pillars-02Voxel --image paddle3d_centerpoint_n008_LIDAR_TOP__1533151603547590.pcd.bin --config_path $CONFIG_PATH
./benchmark_pp3d_cadnn --model CADNN_OCRNet-HRNetW18 --image paddle3d_cadnn_kitti_000780.png --config_path $CONFIG_PATH
./benchmark_pp3d_centerpoint --model CenterPoint-Pillars-02Voxel --image paddle3d_centerpoint_n008_LIDAR_TOP__1533151603547590.pcd.bin --config_path $CONFIG_PATH
./benchmark_pp3d_caddn --model CADDN_OCRNet-HRNetW18 --image paddle3d_caddn_kitti_000780.png --config_path $CONFIG_PATH

set +x
set +x
2 changes: 1 addition & 1 deletion benchmark/paddlex/get_models.sh
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ download PP-Structurev2-SLANet.tgz
download PP-Structurev2-vi-layoutxlm.tgz

# Paddle3D
download CADNN_OCRNet-HRNetW18.tgz
download CADDN_OCRNet-HRNetW18.tgz
download CenterPoint-Pillars-02Voxel.tgz
download PETRv1_v99.tgz
download PETRv2_v99.tgz
Expand Down
37 changes: 23 additions & 14 deletions fastdeploy/vision/perception/paddle3d/petr/petr.cc
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ namespace vision {
namespace perception {

Petr::Petr(const std::string& model_file, const std::string& params_file,
const std::string& config_file, const RuntimeOption& custom_option,
const ModelFormat& model_format)
const std::string& config_file, const RuntimeOption& custom_option,
const ModelFormat& model_format)
: preprocessor_(config_file) {
valid_cpu_backends = {Backend::PDINFER};
valid_gpu_backends = {Backend::PDINFER};
Expand All @@ -41,29 +41,38 @@ bool Petr::Initialize() {
return true;
}

bool Petr::Predict(const cv::Mat& im, PerceptionResult* result) {
std::vector<PerceptionResult> results;
if (!BatchPredict({im}, &results)) {
return false;
}
if (results.size()) {
*result = std::move(results[0]);
}
return true;
bool Petr::Predict(const cv::Mat& images, PerceptionResult* results) {
FDERROR << "Petr inference only support 6(V1) or 12(V2) images" << std::endl;
return false;
}

bool Petr::BatchPredict(const std::vector<cv::Mat>& images,
std::vector<PerceptionResult>* results) {
std::vector<PerceptionResult>* results) {
if ((images.size() != 6) && (images.size() != 12)) {
FDERROR << "Petr only support 6(V1) or 12(V2) images";
return false;
}
std::vector<FDMat> fd_images = WrapMat(images);

if (!preprocessor_.Run(&fd_images, &reused_input_tensors_)) {
FDERROR << "Failed to preprocess the input image." << std::endl;
return false;
}


// Note: un-commented the codes below to show the debug info.
// reused_input_tensors_[0].PrintInfo();
// reused_input_tensors_[1].PrintInfo();
// reused_input_tensors_[2].PrintInfo();

reused_input_tensors_[0].name = InputInfoOfRuntime(0).name;
reused_input_tensors_[1].name = InputInfoOfRuntime(1).name;
reused_input_tensors_[2].name = InputInfoOfRuntime(2).name;
if (images.size() == 12) {
// for Petr V2 timestamp
reused_input_tensors_[2].name = InputInfoOfRuntime(2).name;
} else {
// for Petr V1
reused_input_tensors_.pop_back();
}

if (!Infer(reused_input_tensors_, &reused_output_tensors_)) {
FDERROR << "Failed to inference by runtime." << std::endl;
Expand Down
132 changes: 25 additions & 107 deletions fastdeploy/vision/perception/paddle3d/petr/preprocessor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
// limitations under the License.

#include "fastdeploy/vision/perception/paddle3d/petr/preprocessor.h"

#include <iostream>

#include "fastdeploy/function/concat.h"
Expand All @@ -31,78 +32,20 @@ PetrPreprocessor::PetrPreprocessor(const std::string& config_file) {

bool PetrPreprocessor::BuildPreprocessPipelineFromConfig() {
processors_.clear();
YAML::Node cfg;
try {
cfg = YAML::LoadFile(config_file_);
} catch (YAML::BadFile& e) {
FDERROR << "Failed to load yaml file " << config_file_
<< ", maybe you should check this file." << std::endl;
return false;
}

// read for preprocess
bool has_permute = false;
for (const auto& op : cfg["Preprocess"]) {
std::string op_name = op["type"].as<std::string>();
if (op_name == "NormalizeImage") {
auto mean = op["mean"].as<std::vector<float>>();
auto std = op["std"].as<std::vector<float>>();
bool is_scale = true;
if (op["is_scale"]) {
is_scale = op["is_scale"].as<bool>();
}
std::string norm_type = "mean_std";
if (op["norm_type"]) {
norm_type = op["norm_type"].as<std::string>();
}
if (norm_type != "mean_std") {
std::fill(mean.begin(), mean.end(), 0.0);
std::fill(std.begin(), std.end(), 1.0);
}
mean_ = mean;
std_ = std;
} else if (op_name == "Resize") {
bool keep_ratio = op["keep_ratio"].as<bool>();
auto target_size = op["target_size"].as<std::vector<int>>();
int interp = op["interp"].as<int>();
FDASSERT(target_size.size() == 2,
"Require size of target_size be 2, but now it's %lu.",
target_size.size());
if (!keep_ratio) {
int width = target_size[0];
int height = target_size[1];
processors_.push_back(
std::make_shared<Resize>(width, height, -1.0, -1.0, interp, false));
} else {
int min_target_size = std::min(target_size[0], target_size[1]);
int max_target_size = std::max(target_size[0], target_size[1]);
std::vector<int> max_size;
if (max_target_size > 0) {
max_size.push_back(max_target_size);
max_size.push_back(max_target_size);
}
processors_.push_back(std::make_shared<ResizeByShort>(
min_target_size, interp, true, max_size));
}
} else if (op_name == "Permute") {
// Do nothing, do permute as the last operation
has_permute = true;
continue;
} else {
FDERROR << "Unexcepted preprocess operator: " << op_name << "."
<< std::endl;
return false;
}
}
if (!disable_permute_) {
if (has_permute) {
// permute = cast<float> + HWC2CHW
processors_.push_back(std::make_shared<Cast>("float"));
processors_.push_back(std::make_shared<HWC2CHW>());
}
}
processors_.push_back(std::make_shared<Resize>(800, 450));
processors_.push_back(std::make_shared<Crop>(0, 130, 800, 320));

std::vector<float> mean{103.530, 116.280, 123.675};
std::vector<float> std{57.375, 57.120, 58.395};
bool scale = false;
processors_.push_back(std::make_shared<Normalize>(mean, std, scale));
processors_.push_back(std::make_shared<Cast>("float"));
processors_.push_back(std::make_shared<HWC2CHW>());

// Fusion will improve performance
FuseTransforms(&processors_);

input_k_data_ = cfg["k_data"].as<std::vector<float>>();
return true;
}

Expand All @@ -119,16 +62,16 @@ bool PetrPreprocessor::Apply(FDMatBatch* image_batch,
}
// There are 3 outputs, image, k_data, timestamp
outputs->resize(3);
int batch = static_cast<int>(image_batch->mats->size());
int num_cams = static_cast<int>(image_batch->mats->size());

// Allocate memory for k_data
(*outputs)[1].Resize({1, batch, 4, 4}, FDDataType::FP32);
(*outputs)[1].Resize({1, num_cams, 4, 4}, FDDataType::FP32);

// Allocate memory for image_data
(*outputs)[0].Resize({1, batch, 3, 320, 800}, FDDataType::FP32);
(*outputs)[0].Resize({1, num_cams, 3, 320, 800}, FDDataType::FP32);

// Allocate memory for timestamp
(*outputs)[2].Resize({1, batch}, FDDataType::FP32);
(*outputs)[2].Resize({1, num_cams}, FDDataType::FP32);

auto* image_ptr = reinterpret_cast<float*>((*outputs)[0].MutableData());

Expand All @@ -144,53 +87,28 @@ bool PetrPreprocessor::Apply(FDMatBatch* image_batch,
<< processors_[j]->Name() << "." << std::endl;
return false;
}
if (processors_[j]->Name() == "Resize") {
// crop and normalize after Resize
auto img = *(mat->GetOpenCVMat());
cv::Mat crop_img = img(cv::Range(130, 450), cv::Range(0, 800));
Normalize(&crop_img, mean_, std_, scale_);
FDMat fd_mat = WrapMat(crop_img);
image_batch->mats->at(i) = fd_mat;
}
}
}

for (int i = 0; i < batch / 2 * 4 * 4; ++i) {
input_k_data_.emplace_back(input_k_data_[i]);
for (int i = 0; i < num_cams / 2 * 4 * 4; ++i) {
input_k_data_.push_back(input_k_data_[i]);
}
memcpy(k_data_ptr, input_k_data_.data(), num_cams * 16 * sizeof(float));

memcpy(k_data_ptr, input_k_data_.data(), batch * 16 * sizeof(float));

std::vector<float> timestamp(batch, 0.0f);
for (int i = batch / 2; i < batch; ++i) {
std::vector<float> timestamp(num_cams, 0.0f);
for (int i = num_cams / 2; i < num_cams; ++i) {
timestamp[i] = 1.0f;
}
memcpy(timestamp_ptr, timestamp.data(), batch * sizeof(float));
memcpy(timestamp_ptr, timestamp.data(), num_cams * sizeof(float));

FDTensor* tensor = image_batch->Tensor();
FDTensor* tensor = image_batch->Tensor(); // [num_cams,3,320,800]
tensor->ExpandDim(0); // [num_cams,3,320,800] -> [1,num_cams,3,320,800]
(*outputs)[0].SetExternalData(tensor->Shape(), tensor->Dtype(),
tensor->Data(), tensor->device,
tensor->device_id);
return true;
}

void PetrPreprocessor::Normalize(cv::Mat* im, const std::vector<float>& mean,
const std::vector<float>& std, float& scale) {
if (scale) {
(*im).convertTo(*im, CV_32FC3, scale);
}
for (int h = 0; h < im->rows; h++) {
for (int w = 0; w < im->cols; w++) {
im->at<cv::Vec3f>(h, w)[0] =
(im->at<cv::Vec3f>(h, w)[0] - mean[0]) / std[0];
im->at<cv::Vec3f>(h, w)[1] =
(im->at<cv::Vec3f>(h, w)[1] - mean[1]) / std[1];
im->at<cv::Vec3f>(h, w)[2] =
(im->at<cv::Vec3f>(h, w)[2] - mean[2]) / std[2];
}
}
}

} // namespace perception
} // namespace vision
} // namespace fastdeploy
Loading

0 comments on commit ade27d2

Please sign in to comment.