Skip to content

Commit

Permalink
【Hackathon 5th No.25】add gammaln api (update) (#791)
Browse files Browse the repository at this point in the history
* modify documentation to align with the implementation code.

* update
  • Loading branch information
GreatV authored Dec 28, 2023
1 parent 0514e21 commit f867979
Showing 1 changed file with 67 additions and 1 deletion.
68 changes: 67 additions & 1 deletion rfcs/APIs/20230925_api_design_for_gammaln.md
Original file line number Diff line number Diff line change
Expand Up @@ -184,12 +184,34 @@ double lgam_sgn(double x, int *sign)
Pytorch中直接使用的是C++标准库中的 `std::lgamma`,代码如下:
- CPU版本
```C++
Vectorized<T> map(T (*const f)(T)) const {
Vectorized<T> ret;
for (int64_t i = 0; i != size(); i++) {
ret[i] = f(values[i]);
}
return ret;
}
Vectorized<T> lgamma() const {
return map(std::lgamma);
}
```

- GPU版本

```C++
const auto lgamma_string = jiterator_stringify(
template <typename T>
T lgamma_kernel(T a) {
return lgamma(a);
}
); // lgamma_string
```


## 四、对比分析

Scipy 基于 cephes 库实现,Pytorch 基于 C++ 标准库实现。Scipy 中的实现与 Python 标准库中的 `math.lgamma` 一致。Pytorch 中直接使用的是C++标准库中的 `std::lgamma`。而飞桨中暂无 `paddle.gamma` 函数,不好通过Python端实现,因此,需要通过C++端实现。
Expand All @@ -200,12 +222,56 @@ Scipy 基于 cephes 库实现,Pytorch 基于 C++ 标准库实现。Scipy 中

<!-- 参考:[飞桨API 设计及命名规范](https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/dev_guides/api_contributing_guides/api_design_guidelines_standard_cn.html) -->

API设计为 `paddle.gammaln(x)`。其中,`x``Tensor` 类型。`Tensor.gammaln()` 为 Tensor 的方法版本。
API设计为 `paddle.gammaln(x, name=None)`。其中,`x``Tensor` 类型。`Tensor.gammaln()` 为 Tensor 的方法版本。

### API实现方案

参考 PyTorch 采用C++ 实现,实现位置为 Paddle repo `python/paddle/tensor/math.py` 目录。并在 python/paddle/tensor/init.py 中,添加 `gammaln` API,以支持 `Tensor.gammaln` 的调用方式。头文件放在 `paddle/phi/kernels` 目录,cc 文件在 `paddle/phi/kernels/cpu` 目录, cu文件 `paddle/phi/kernels/gpu` 目录。

其中,`Gammaln` 的前向计算实现核心内容可为:

```c++
template <typename T>
struct GammalnFunctor {
GammalnFunctor(const T* x, T* output, int64_t numel)
: x_(x), output_(output), numel_(numel) {}

HOSTDEVICE void operator()(int64_t idx) const {
using MT = typename phi::dtype::MPTypeTrait<T>::Type;
const MT mp_x = static_cast<MT>(x_[idx]);
output_[idx] = static_cast<T>(std::lgamma(mp_x));
}

private:
const T* x_;
T* output_;
int64_t numel_;
};
```
算子的反向计算实现核心内容可为:
```c++
template <typename T>
struct GammalnGradFunctor {
GammalnGradFunctor(const T* dout, const T* x, T* output, int64_t numel)
: dout_(dout), x_(x), output_(output), numel_(numel) {}
HOSTDEVICE void operator()(int64_t idx) const {
using MT = typename phi::dtype::MPTypeTrait<T>::Type;
const MT mp_dout = static_cast<MT>(dout_[idx]);
const MT mp_x = static_cast<MT>(x_[idx]);
output_[idx] = static_cast<T>(mp_dout * digamma<MT>(mp_x));
}
private:
const T* dout_;
const T* x_;
T* output_;
int64_t numel_;
};
```

## 六、测试和验收的考量

<!-- 参考:[新增API 测试及验收规范](https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/dev_guides/api_contributing_guides/api_accpetance_criteria_cn.html) -->
Expand Down

0 comments on commit f867979

Please sign in to comment.