Skip to content

Commit

Permalink
refine docstring of get_buffer
Browse files Browse the repository at this point in the history
  • Loading branch information
HydrogenSulfate committed Sep 27, 2024
1 parent b3a6408 commit 15a7e75
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 21 deletions.
14 changes: 10 additions & 4 deletions source/api_cc/include/DeepPotPD.h
Original file line number Diff line number Diff line change
Expand Up @@ -235,14 +235,20 @@ class DeepPotPD : public DeepPotBase {
void get_type_map(std::string& type_map);

/**
* @brief Get the type map (element name of the atom types) of this model.
* @param[out] type_map The type map of this model.
* @brief Get the buffer of this model.
* @param[in] buffer_name Buffer name.
* @param[out] buffer_array Buffer array.
**/
template<typename BUFFERTYPE>
void get_buffer(const std::string &buffer_name, std::vector<BUFFERTYPE> &buffer_arr);
void get_buffer(const std::string &buffer_name, std::vector<BUFFERTYPE> &buffer_array);

/**
* @brief Get the buffer of this model.
* @param[in] buffer_name Buffer name.
* @param[out] buffer_scalar Buffer scalar.
**/
template<typename BUFFERTYPE>
void get_buffer(const std::string &buffer_name, BUFFERTYPE &buffer_arr);
void get_buffer(const std::string &buffer_name, BUFFERTYPE &buffer_scalar);

/**
* @brief Get whether the atom dimension of aparam is nall instead of fparam.
Expand Down
34 changes: 17 additions & 17 deletions source/api_cc/src/DeepPotPD.cc
Original file line number Diff line number Diff line change
Expand Up @@ -379,23 +379,6 @@ template void DeepPotPD::compute<float, std::vector<ENERGYTYPE>>(
const std::vector<float>& aparam,
const bool atomic);

/* general function except for string buffer */
template<typename BUFFERVTYPE>
void DeepPotPD::get_buffer(const std::string &buffer_name, std::vector<BUFFERVTYPE> &buffer_arr) {
auto buffer_tensor = predictor->GetOutputHandle(buffer_name);
auto buffer_shape = buffer_tensor->shape();
int buffer_size = std::accumulate(buffer_shape.begin(), buffer_shape.end(), 1, std::multiplies<int>());
buffer_arr.resize(buffer_size);
buffer_tensor->CopyToCpu(buffer_arr.data());
}

template<typename BUFFERTYPE>
void DeepPotPD::get_buffer(const std::string &buffer_name, BUFFERTYPE &buffer) {
std::vector<BUFFERTYPE> buffer_arr(1);
DeepPotPD::get_buffer<BUFFERTYPE>(buffer_name, buffer_arr);
buffer = buffer_arr[0];
}

/* type_map is regarded as a special string buffer
that need to be postprocessed */
void DeepPotPD::get_type_map(std::string& type_map) {
Expand All @@ -410,6 +393,23 @@ void DeepPotPD::get_type_map(std::string& type_map) {
}
}

/* general function except for string buffer */
template<typename BUFFERTYPE>
void DeepPotPD::get_buffer(const std::string &buffer_name, std::vector<BUFFERTYPE> &buffer_array) {
auto buffer_tensor = predictor->GetOutputHandle(buffer_name);
auto buffer_shape = buffer_tensor->shape();
int buffer_size = std::accumulate(buffer_shape.begin(), buffer_shape.end(), 1, std::multiplies<int>());
buffer_array.resize(buffer_size);
buffer_tensor->CopyToCpu(buffer_array.data());
}

template<typename BUFFERTYPE>
void DeepPotPD::get_buffer(const std::string &buffer_name, BUFFERTYPE &buffer_scalar) {
std::vector<BUFFERTYPE> buffer_array(1);
DeepPotPD::get_buffer<BUFFERTYPE>(buffer_name, buffer_array);
buffer_scalar = buffer_array[0];
}

// forward to template method
void DeepPotPD::computew(std::vector<double>& ener,
std::vector<double>& force,
Expand Down

0 comments on commit 15a7e75

Please sign in to comment.