
FDTensor: support GPU device #190

Merged: 13 commits, Sep 8, 2022
16 changes: 8 additions & 8 deletions csrc/fastdeploy/backends/ort/ort_backend.cc
@@ -170,20 +170,20 @@ void OrtBackend::CopyToCpu(const Ort::Value& value, FDTensor* tensor, const std:
   size_t numel = info.GetElementCount();

   if (data_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) {
-    tensor->Allocate(info.GetShape(), FDDataType::FP32, name);
-    memcpy(static_cast<void*>(tensor->MutableData()), value.GetTensorData<void*>(),
+    tensor->Resize(numel * sizeof(float));
+    memcpy(static_cast<void*>(tensor->Data()), value.GetTensorData<void*>(),
            numel * sizeof(float));
   } else if (data_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32) {
-    tensor->Allocate(info.GetShape(), FDDataType::INT32, name);
-    memcpy(static_cast<void*>(tensor->MutableData()), value.GetTensorData<void*>(),
+    tensor->Resize(numel * sizeof(int32_t));
+    memcpy(static_cast<void*>(tensor->Data()), value.GetTensorData<void*>(),
            numel * sizeof(int32_t));
   } else if (data_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64) {
-    tensor->Allocate(info.GetShape(), FDDataType::INT64, name);
-    memcpy(static_cast<void*>(tensor->MutableData()), value.GetTensorData<void*>(),
+    tensor->Resize(numel * sizeof(int64_t));
+    memcpy(static_cast<void*>(tensor->Data()), value.GetTensorData<void*>(),
            numel * sizeof(int64_t));
   } else if (data_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE) {
-    tensor->Allocate(info.GetShape(), FDDataType::FP64, name);
-    memcpy(static_cast<void*>(tensor->MutableData()), value.GetTensorData<void*>(),
+    tensor->Resize(numel * sizeof(double));
+    memcpy(static_cast<void*>(tensor->Data()), value.GetTensorData<void*>(),
            numel * sizeof(double));
   } else {
     FDASSERT(
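A minimal sketch of the copy-out pattern this hunk switches to, assuming the FastDeploy headers (the include path and a host-side destination tensor are assumptions, not part of this diff); it is an illustration, not the backend's actual code:

#include <cstddef>
#include <cstring>

#include "fastdeploy/core/fd_tensor.h"  // assumed header location

// Size the owned buffer by byte count, then write through Data(); this is the
// pattern that replaces the old Allocate(shape, dtype, name) + MutableData().
void CopyFloatsToTensor(const float* src, size_t numel,
                        fastdeploy::FDTensor* dst) {
  dst->Resize(numel * sizeof(float));
  std::memcpy(dst->Data(), src, numel * sizeof(float));
}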
13 changes: 10 additions & 3 deletions csrc/fastdeploy/backends/paddle/paddle_backend.cc
@@ -79,16 +79,23 @@ bool PaddleBackend::InitFromPaddle(const std::string& model_file,
 }

 TensorInfo PaddleBackend::GetInputInfo(int index) {
-  FDASSERT(index < NumInputs(), "The index: %d should less than the number of inputs: %d.", index, NumInputs());
+  FDASSERT(index < NumInputs(),
+           "The index: %d should less than the number of inputs: %d.", index,
+           NumInputs());
   return inputs_desc_[index];
 }

+std::vector<TensorInfo> PaddleBackend::GetInputInfo() { return inputs_desc_; }
+
 TensorInfo PaddleBackend::GetOutputInfo(int index) {
   FDASSERT(index < NumOutputs(),
-           "The index: %d should less than the number of outputs %d.", index, NumOutputs());
+           "The index: %d should less than the number of outputs %d.", index,
+           NumOutputs());
   return outputs_desc_[index];
 }

+std::vector<TensorInfo> PaddleBackend::GetOutputInfo() { return outputs_desc_; }
+
 bool PaddleBackend::Infer(std::vector<FDTensor>& inputs,
                           std::vector<FDTensor>* outputs) {
   if (inputs.size() != inputs_desc_.size()) {

@@ -100,7 +107,7 @@ bool PaddleBackend::Infer(std::vector<FDTensor>& inputs,

   for (size_t i = 0; i < inputs.size(); ++i) {
     auto handle = predictor_->GetInputHandle(inputs[i].name);
-    ShareTensorFromCpu(handle.get(), inputs[i]);
+    ShareTensorFromFDTensor(handle.get(), inputs[i]);
   }

   predictor_->Run();
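A hedged usage sketch of the two introspection flavors now offered by PaddleBackend, assuming the FastDeploy backend header and an already-initialized backend object; TensorInfo's fields are not part of this diff, so only the containers are touched:

#include <vector>

#include "fastdeploy/backends/paddle/paddle_backend.h"

void InspectIO(fastdeploy::PaddleBackend& backend) {
  // New overloads: fetch every input/output descriptor at once.
  std::vector<fastdeploy::TensorInfo> inputs = backend.GetInputInfo();
  std::vector<fastdeploy::TensorInfo> outputs = backend.GetOutputInfo();
  // The index-based overloads remain and are bounds-checked by FDASSERT.
  if (!inputs.empty()) {
    fastdeploy::TensorInfo first_input = backend.GetInputInfo(0);
    (void)first_input;
  }
  (void)outputs;
}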
7 changes: 6 additions & 1 deletion csrc/fastdeploy/backends/paddle/paddle_backend.h
@@ -44,8 +44,11 @@ struct PaddleBackendOption {
   std::vector<std::string> delete_pass_names = {};
 };

+// convert FD device to paddle place type
+paddle_infer::PlaceType ConvertFDDeviceToPlace(Device device);
+
 // Share memory buffer with paddle_infer::Tensor from fastdeploy::FDTensor
-void ShareTensorFromCpu(paddle_infer::Tensor* tensor, FDTensor& fd_tensor);
+void ShareTensorFromFDTensor(paddle_infer::Tensor* tensor, FDTensor& fd_tensor);

 // Copy memory data from paddle_infer::Tensor to fastdeploy::FDTensor
 void CopyTensorToCpu(std::unique_ptr<paddle_infer::Tensor>& tensor,

@@ -72,6 +75,8 @@ class PaddleBackend : public BaseBackend {

   TensorInfo GetInputInfo(int index);
   TensorInfo GetOutputInfo(int index);
+  std::vector<TensorInfo> GetInputInfo();
+  std::vector<TensorInfo> GetOutputInfo();

  private:
   paddle_infer::Config config_;
28 changes: 21 additions & 7 deletions csrc/fastdeploy/backends/paddle/util.cc
@@ -15,23 +15,33 @@
 #include "fastdeploy/backends/paddle/paddle_backend.h"

 namespace fastdeploy {
-void ShareTensorFromCpu(paddle_infer::Tensor* tensor, FDTensor& fd_tensor) {
+paddle_infer::PlaceType ConvertFDDeviceToPlace(Device device) {
+  if (device == Device::GPU) {
+    return paddle_infer::PlaceType::kGPU;
+  }
+  return paddle_infer::PlaceType::kCPU;
+}
+
+void ShareTensorFromFDTensor(paddle_infer::Tensor* tensor,
+                             FDTensor& fd_tensor) {
   std::vector<int> shape(fd_tensor.shape.begin(), fd_tensor.shape.end());
   tensor->Reshape(shape);
+  auto place = ConvertFDDeviceToPlace(fd_tensor.device);
   if (fd_tensor.dtype == FDDataType::FP32) {
     tensor->ShareExternalData(static_cast<const float*>(fd_tensor.Data()),
-                              shape, paddle_infer::PlaceType::kCPU);
+                              shape, place);
     return;
   } else if (fd_tensor.dtype == FDDataType::INT32) {
     tensor->ShareExternalData(static_cast<const int32_t*>(fd_tensor.Data()),
-                              shape, paddle_infer::PlaceType::kCPU);
+                              shape, place);
     return;
   } else if (fd_tensor.dtype == FDDataType::INT64) {
     tensor->ShareExternalData(static_cast<const int64_t*>(fd_tensor.Data()),
-                              shape, paddle_infer::PlaceType::kCPU);
+                              shape, place);
     return;
   }
-  FDASSERT(false, "Unexpected data type(%s) while infer with PaddleBackend.", Str(fd_tensor.dtype).c_str());
+  FDASSERT(false, "Unexpected data type(%s) while infer with PaddleBackend.",
+           Str(fd_tensor.dtype).c_str());
 }

 void CopyTensorToCpu(std::unique_ptr<paddle_infer::Tensor>& tensor,

@@ -51,7 +61,8 @@ void CopyTensorToCpu(std::unique_ptr<paddle_infer::Tensor>& tensor,
     tensor->CopyToCpu(static_cast<int64_t*>(fd_tensor->MutableData()));
     return;
   }
-  FDASSERT(false, "Unexpected data type(%s) while infer with PaddleBackend.", Str(fd_tensor->dtype).c_str());
+  FDASSERT(false, "Unexpected data type(%s) while infer with PaddleBackend.",
+           Str(fd_tensor->dtype).c_str());
 }

 FDDataType PaddleDataTypeToFD(const paddle_infer::DataType& dtype) {

@@ -65,7 +76,10 @@ FDDataType PaddleDataTypeToFD(const paddle_infer::DataType& dtype) {
   } else if (dtype == paddle_infer::UINT8) {
     fd_dtype = FDDataType::UINT8;
   } else {
-    FDASSERT(false, "Unexpected data type: %d while call CopyTensorToCpu in PaddleBackend.", int(dtype));
+    FDASSERT(
+        false,
+        "Unexpected data type: %d while call CopyTensorToCpu in PaddleBackend.",
+        int(dtype));
   }
   return fd_dtype;
 }
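A small sketch of how the new device-aware path fits together, assuming the same headers as above; it simply forwards to the functions added in this file:

#include "fastdeploy/backends/paddle/paddle_backend.h"

// Share an FDTensor with a Paddle tensor; the place is derived from the
// tensor's device field instead of being hard-coded to kCPU, so a
// GPU-resident buffer is shared as PlaceType::kGPU.
void ShareExample(paddle_infer::Tensor* pd_tensor,
                  fastdeploy::FDTensor& fd_tensor) {
  auto place = fastdeploy::ConvertFDDeviceToPlace(fd_tensor.device);
  (void)place;  // kGPU when fd_tensor.device == Device::GPU, kCPU otherwise
  fastdeploy::ShareTensorFromFDTensor(pd_tensor, fd_tensor);
}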
8 changes: 6 additions & 2 deletions csrc/fastdeploy/backends/tensorrt/trt_backend.cc
@@ -365,8 +365,12 @@ void TrtBackend::AllocateBufferInDynamicShape(
         "Cannot find output: %s of tensorrt network from the original model.",
         outputs_desc_[i].name.c_str());
     auto ori_idx = iter->second;
-    std::vector<int64_t> shape(output_dims.d, output_dims.d + output_dims.nbDims);
-    (*outputs)[ori_idx].Allocate(shape, GetFDDataType(outputs_desc_[i].dtype), outputs_desc_[i].name);
+    (*outputs)[ori_idx].dtype = GetFDDataType(outputs_desc_[i].dtype);
+    (*outputs)[ori_idx].shape.assign(output_dims.d,
+                                     output_dims.d + output_dims.nbDims);
+    (*outputs)[ori_idx].name = outputs_desc_[i].name;
+    (*outputs)[ori_idx].Resize(Volume(output_dims) *
+                               TrtDataTypeSize(outputs_desc_[i].dtype));
     if ((*outputs)[ori_idx].Nbytes() >
         outputs_buffer_[outputs_desc_[i].name].nbBytes()) {
       outputs_buffer_[outputs_desc_[i].name].resize(output_dims);
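Volume() and TrtDataTypeSize() are not part of this diff; the sketch below is a guess at their intent (element count times bytes per element), to show how the output buffer size is derived before Resize():

#include <cstdint>

// Illustrative stand-in for a Volume()-style helper over an nvinfer1::Dims-like
// shape: the product of the first nbDims entries of d.
template <typename DimsT>
int64_t VolumeSketch(const DimsT& dims) {
  int64_t volume = 1;
  for (int i = 0; i < dims.nbDims; ++i) {
    volume *= dims.d[i];
  }
  return volume;
}
// The hunk above then resizes the FDTensor to
// VolumeSketch(output_dims) * bytes-per-element for the TensorRT dtype.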
53 changes: 53 additions & 0 deletions csrc/fastdeploy/core/allocate.h
@@ -0,0 +1,53 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once

#ifdef WITH_GPU
#include <cuda_runtime_api.h>
#endif

#include <memory>
#include <new>
#include <numeric>
#include <string>
#include <vector>

class FDHostAllocator {
public:
bool operator()(void** ptr, size_t size) const {
*ptr = malloc(size);
return *ptr != nullptr;
}
};

class FDHostFree {
public:
void operator()(void* ptr) const { free(ptr); }
};

#ifdef WITH_GPU

class FDDeviceAllocator {
public:
bool operator()(void** ptr, size_t size) const {
return cudaMalloc(ptr, size) == cudaSuccess;
}
};

class FDDeviceFree {
public:
void operator()(void* ptr) const { cudaFree(ptr); }
};

#endif
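A minimal host-side usage sketch of the new allocator functors; the GPU pair is used the same way but requires building with -DWITH_GPU=ON and a CUDA runtime (the include path is an assumption):

#include "fastdeploy/core/allocate.h"  // assumed install path of this header

int main() {
  void* ptr = nullptr;
  FDHostAllocator alloc;
  if (alloc(&ptr, 1024)) {  // malloc-backed; returns false on failure
    FDHostFree{}(ptr);      // free-backed counterpart
  }
  return 0;
}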
91 changes: 66 additions & 25 deletions csrc/fastdeploy/core/fd_tensor.cc
@@ -25,55 +25,61 @@ void* FDTensor::MutableData() {
   if (external_data_ptr != nullptr) {
     return external_data_ptr;
   }
-  return data.data();
+  return buffer_;
 }

 void* FDTensor::Data() {
-  if (external_data_ptr != nullptr) {
-    if (device == Device::GPU) {
+  if (device == Device::GPU) {
 #ifdef WITH_GPU
-      // need to copy cuda mem to cpu first
-      temporary_cpu_buffer.resize(Nbytes());
+    temporary_cpu_buffer.resize(Nbytes());
+    // need to copy cuda mem to cpu first
+    if (external_data_ptr != nullptr) {
       FDASSERT(cudaMemcpy(temporary_cpu_buffer.data(), external_data_ptr,
                           Nbytes(), cudaMemcpyDeviceToHost) == 0,
                "[ERROR] Error occurs while copy memory from GPU to CPU");
-      return temporary_cpu_buffer.data();
-#else
-      FDASSERT(false,
-               "The FastDeploy didn't compile under -DWITH_GPU=ON, so this is "
-               "an unexpected problem happend.");
-#endif
-
     } else {
-      return external_data_ptr;
+      FDASSERT(cudaMemcpy(temporary_cpu_buffer.data(), buffer_, Nbytes(),
+                          cudaMemcpyDeviceToHost) == 0,
+               "[ERROR] Error occurs while buffer copy memory from GPU to CPU");
     }
+    return temporary_cpu_buffer.data();
+#else
+    FDASSERT(false,
+             "The FastDeploy didn't compile under -DWITH_GPU=ON, so this is "
+             "an unexpected problem happend.");
+#endif
   }
-  return data.data();
+  return MutableData();
 }

-const void* FDTensor::Data() const {
-  if (external_data_ptr != nullptr) {
-    return external_data_ptr;
-  }
-  return data.data();
-}
+// const void* FDTensor::Data() const {
+//   if (external_data_ptr != nullptr) {
+//     return external_data_ptr;
+//   }
+//   return data.data();
+// }

 void FDTensor::SetExternalData(const std::vector<int64_t>& new_shape,
-                               const FDDataType& data_type, void* data_buffer) {
+                               const FDDataType& data_type, void* data_buffer,
+                               const Device& new_device) {
   dtype = data_type;
   shape.assign(new_shape.begin(), new_shape.end());
   external_data_ptr = data_buffer;
+  device = new_device;
 }

 void FDTensor::Allocate(const std::vector<int64_t>& new_shape,
                         const FDDataType& data_type,
-                        const std::string& tensor_name) {
+                        const std::string& tensor_name,
+                        const Device& new_device) {
   dtype = data_type;
   name = tensor_name;
   shape.assign(new_shape.begin(), new_shape.end());
-  int unit = FDDataTypeSize(data_type);
-  int total_size =
-      std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>());
-  data.resize(total_size * unit);
+  device = new_device;
+  size_t nbytes = Nbytes();
+  FDASSERT(AllocFn(nbytes),
+           "The FastDeploy FDTensor allocate cpu memory error");
 }

 int FDTensor::Nbytes() const { return Numel() * FDDataTypeSize(dtype); }
@@ -82,6 +88,41 @@ int FDTensor::Numel() const {
   return std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>());
 }

+void FDTensor::Resize(size_t new_nbytes) {
+  size_t nbytes = Nbytes();
+  if (new_nbytes > nbytes) {
+    FreeFn();
+    AllocFn(new_nbytes);
+  }
+}
+
+void FDTensor::Resize(const std::vector<int64_t>& new_shape) {
+  int numel = Numel();
+  int new_numel = std::accumulate(new_shape.begin(), new_shape.end(), 1,
+                                  std::multiplies<int>());
+  shape.assign(new_shape.begin(), new_shape.end());
+  if (new_numel > numel) {
+    FreeFn();
+    size_t nbytes = new_numel * FDDataTypeSize(dtype);
+    AllocFn(nbytes);
+  }
+}
+
+void FDTensor::Resize(const std::vector<int64_t>& new_shape,
+                      const FDDataType& data_type, const Device& new_device) {
+  size_t nbytes = Nbytes();
+  int new_nbytes = std::accumulate(new_shape.begin(), new_shape.end(), 1,
+                                   std::multiplies<int>()) *
+                   FDDataTypeSize(dtype);
+  shape.assign(new_shape.begin(), new_shape.end());
+  dtype = data_type;
+  if (new_nbytes > nbytes) {
+    FreeFn();
+    AllocFn(new_nbytes);
+  }
+  device = new_device;
+}
+
 template <typename T>
 void CalculateStatisInfo(void* src_ptr, int size, double* mean, double* max,
                          double* min) {
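A hedged usage sketch of the Resize overloads added above, assuming the FDTensor header and the usual Device::CPU enumerator (neither is shown in this diff); all three overloads only reallocate when the requested size exceeds the current allocation, so shrinking reuses the existing buffer:

#include <cstddef>

#include "fastdeploy/core/fd_tensor.h"  // assumed header location

void ResizeExample(fastdeploy::FDTensor& t) {
  using fastdeploy::Device;
  using fastdeploy::FDDataType;
  // Shape + dtype + device overload: allocates on first use.
  t.Resize({1, 3, 224, 224}, FDDataType::FP32, Device::CPU);
  // A smaller request keeps the existing allocation.
  t.Resize({1, 3, 112, 112}, FDDataType::FP32, Device::CPU);
  // Byte-count overload: grows the raw buffer without touching shape or dtype.
  t.Resize(static_cast<size_t>(t.Nbytes()) * 2);
}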