Skip to content

Commit

Permalink
[Bug Fix] fix trt backend page-locked error (#2095)
Browse files Browse the repository at this point in the history
* [Bug Fix] fix trt backend page-locked error

* Update trt_backend.cc
  • Loading branch information
DefTruth committed Jul 11, 2023
1 parent 4c1e80b commit cf1ff20
Showing 1 changed file with 25 additions and 9 deletions.
34 changes: 25 additions & 9 deletions fastdeploy/runtime/backends/tensorrt/trt_backend.cc
Original file line number Diff line number Diff line change
Expand Up @@ -470,16 +470,32 @@ void TrtBackend::SetInputs(const std::vector<FDTensor>& inputs) {
if (item.dtype == FDDataType::INT64) {
int64_t* data = static_cast<int64_t*>(const_cast<void*>(item.Data()));
std::vector<int32_t> casted_data(data, data + item.Numel());
FDASSERT(cudaMemcpyAsync(inputs_device_buffer_[item.name].data(),
static_cast<void*>(casted_data.data()),
item.Nbytes() / 2, cudaMemcpyHostToDevice,
stream_) == 0,
"Error occurs while copy memory from CPU to GPU.");
// FDASSERT(cudaMemcpyAsync(inputs_device_buffer_[item.name].data(),
// static_cast<void*>(casted_data.data()),
// item.Nbytes() / 2, cudaMemcpyHostToDevice,
// stream_) == 0,
// "Error occurs while copy memory from CPU to GPU.");
// WARN: For cudaMemcpyHostToDevice direction, cudaMemcpyAsync need page-locked host
// memory to avoid any overlap to occur. The page-locked feature need by cudaMemcpyAsync
// may not guarantee by FDTensor now. Reference:
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#creation-and-destruction
FDASSERT(cudaMemcpy(inputs_device_buffer_[item.name].data(),
static_cast<void*>(casted_data.data()),
item.Nbytes() / 2, cudaMemcpyHostToDevice) == 0,
"Error occurs while copy memory from CPU to GPU.");
} else {
FDASSERT(cudaMemcpyAsync(inputs_device_buffer_[item.name].data(),
item.Data(), item.Nbytes(),
cudaMemcpyHostToDevice, stream_) == 0,
"Error occurs while copy memory from CPU to GPU.");
// FDASSERT(cudaMemcpyAsync(inputs_device_buffer_[item.name].data(),
// item.Data(), item.Nbytes(),
// cudaMemcpyHostToDevice, stream_) == 0,
// "Error occurs while copy memory from CPU to GPU.");
// WARN: For cudaMemcpyHostToDevice direction, cudaMemcpyAsync need page-locked host
// memory to avoid any overlap to occur. The page-locked feature need by cudaMemcpyAsync
// may not guarantee by FDTensor now. Reference:
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#creation-and-destruction
FDASSERT(cudaMemcpy(inputs_device_buffer_[item.name].data(),
item.Data(), item.Nbytes(),
cudaMemcpyHostToDevice) == 0,
"Error occurs while copy memory from CPU to GPU.");
}
}
// binding input buffer
Expand Down

0 comments on commit cf1ff20

Please sign in to comment.