
【Hackathon 5th No.19】Add ContinuousBernoulli and MultivariateNormal API #58004

Merged
merged 29 commits into PaddlePaddle:develop
Dec 18, 2023

Conversation

NKNaN
Contributor

@NKNaN NKNaN commented Oct 11, 2023

PR types

New features

PR changes

APIs

Description

Add ContinuousBernoulli and MultivariateNormal API

@paddle-bot

paddle-bot bot commented Oct 11, 2023

Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI first. See the Paddle CI Manual for details.

@NKNaN
Contributor Author

NKNaN commented Oct 11, 2023

Math Derivation for entropy of the Continuous Bernoulli distribution and kl_divergence of 2 Continuous Bernoulli distributions:

  • entropy:
$$\begin{aligned} H &= -\int_x C(\lambda) \lambda^x (1-\lambda)^{1-x} \log\{C(\lambda) \lambda^x (1-\lambda)^{1-x}\} dx \\\ & = -\int_0^1 C \lambda^x (1-\lambda)^{1-x} \left[ \log C + x \log \lambda + (1 -x) \log (1 - \lambda)\right] dx \\\ & = -\left[ C \log C \int_0^1 \lambda^x (1-\lambda)^{1-x} dx + C \log \lambda \int_0^1 x \lambda^x (1-\lambda)^{1-x} dx + C \log(1 - \lambda) \int_0^1 (1-x) \lambda^x (1-\lambda)^{1-x} dx \right] \\\ & = - \left[ \log C + \mathbb{E}(X) \log \lambda + \mathbb{E}(1 - X) \log(1 - \lambda) \right] \\\ & = -\log C + \left[ \log (1 - \lambda) -\log \lambda \right] \mathbb{E}(X) - \log(1 - \lambda) \end{aligned}$$
  • kl_divergence:
$$\begin{aligned} \mathcal{D}_{KL}(p_1|| p_2) &= \int_x p_1(x)\log\frac{p_1(x)}{p_2(x)} dx \\\ & = \int_0^1 C_1 \lambda_1^x (1-\lambda_1)^{1-x} \{\log[C_1 \lambda_1^x (1-\lambda_1)^{1-x}] - \log[C_2 \lambda_2^x (1-\lambda_2)^{1-x}]\} dx \\\ & = -H - [C_1 \log C_2 \int_0^1 \lambda_1^x (1-\lambda_1)^{1-x} dx + C_1 \log \lambda_2 \int_0^1 x \lambda_1^x (1-\lambda_1)^{1-x} dx + C_1 \log (1-\lambda_2) \int_0^1 (1-x) \lambda_1^x (1-\lambda_1)^{1-x} dx] \\\ & = -H - [\log C_2 + \log \lambda_2 \mathbb{E}_1(X) + \log (1-\lambda_2) \mathbb{E}_1(1-X) ] \\\ & = - H - \{\log C_2 + [\log \lambda_2 - \log (1-\lambda_2)] \mathbb{E}_1(X) + \log (1-\lambda_2) \} \end{aligned}$$
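The two closed forms above can be sanity-checked numerically. Below is a sketch with numpy/scipy; `norm_const` and `mean` use the standard Continuous Bernoulli normalizing constant $C(\lambda) = 2\,\mathrm{artanh}(1-2\lambda)/(1-2\lambda)$ and mean, which this thread assumes rather than derives.

```python
import numpy as np
from scipy.integrate import quad

def norm_const(lam):
    # C(lam) = 2 * artanh(1 - 2*lam) / (1 - 2*lam); C(0.5) = 2 by continuity.
    if abs(lam - 0.5) < 1e-12:
        return 2.0
    return 2.0 * np.arctanh(1.0 - 2.0 * lam) / (1.0 - 2.0 * lam)

def mean(lam):
    # E(X) of the Continuous Bernoulli (standard closed form).
    if abs(lam - 0.5) < 1e-12:
        return 0.5
    return lam / (2.0 * lam - 1.0) + 1.0 / (2.0 * np.arctanh(1.0 - 2.0 * lam))

def pdf(x, lam):
    return norm_const(lam) * lam**x * (1.0 - lam)**(1.0 - x)

def entropy_closed(lam):
    # H = -log C + [log(1-lam) - log lam] * E(X) - log(1-lam)
    return (-np.log(norm_const(lam))
            + (np.log1p(-lam) - np.log(lam)) * mean(lam)
            - np.log1p(-lam))

def kl_closed(l1, l2):
    # KL = -H_1 - {log C_2 + [log l2 - log(1-l2)] * E_1(X) + log(1-l2)}
    return (-entropy_closed(l1)
            - (np.log(norm_const(l2))
               + (np.log(l2) - np.log1p(-l2)) * mean(l1)
               + np.log1p(-l2)))

lam = 0.3
h_num, _ = quad(lambda x: -pdf(x, lam) * np.log(pdf(x, lam)), 0.0, 1.0)
print(abs(h_num - entropy_closed(lam)))  # numerical and closed forms agree

l1, l2 = 0.3, 0.7
kl_num, _ = quad(lambda x: pdf(x, l1) * np.log(pdf(x, l1) / pdf(x, l2)), 0.0, 1.0)
print(abs(kl_num - kl_closed(l1, l2)))
```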

@NKNaN
Contributor Author

NKNaN commented Oct 11, 2023

Math Derivation for entropy of the Multivariate Normal distribution and kl_divergence of 2 Multivariate Normal distributions:

  • entropy:
$$\begin{aligned} H &= -\int_x f(x) \log f(x) dx \\\ & = -\int_{x \in \mathbb{R}^n} f(x) \{ -\frac{n}{2}\log(2\pi) -\frac{1}{2} (x-\mu)^{\intercal} \Sigma^{-1} (x-\mu) - \frac{1}{2}\log (\det\Sigma) \} dx \\\ & = -\int_{x \in \mathbb{R}^n} f(x) \{ -\frac{n}{2}\log(2\pi) -\frac{1}{2} [A^{-1}(x-\mu)]^{\intercal}[A^{-1}(x-\mu)] - \log (\det A) \} dx \\\ & = \frac{n}{2} \log(2\pi) + \log {\det A} + \frac{1}{2}\int_{x \in \mathbb{R}^n} [A^{-1}(x-\mu)]^{\intercal}[A^{-1}(x-\mu)] f(x) dx\\\ & = \frac{n}{2} \log(2\pi) + \log {\det A} + \frac{1}{2} \mathbb{E}[(X-\mu)^{\intercal} \Sigma^{-1} (X - \mu)] \\\ & = \frac{n}{2} \log(2\pi) + \log {\det A} + \frac{1}{2} \mathbb{E}[tr[(X-\mu)^{\intercal} \Sigma^{-1} (X - \mu)] ] \\\ & = \frac{n}{2} \log(2\pi) + \log {\det A} + \frac{1}{2} \mathbb{E}[tr[\Sigma^{-1} (X - \mu) (X-\mu)^{\intercal}]] \\\ & = \frac{n}{2} \log(2\pi) + \log {\det A} + \frac{1}{2} tr[\mathbb{E}[\Sigma^{-1} (X - \mu) (X-\mu)^{\intercal}]] \\\ & = \frac{n}{2} \log(2\pi) + \log {\det A} + \frac{1}{2} tr[\Sigma^{-1} \mathbb{E}[ (X - \mu) (X-\mu)^{\intercal}]] \\\ & = \frac{n}{2} \log(2\pi) + \log {\det A} + \frac{1}{2} tr[\Sigma^{-1} \Sigma] \\\ & = \frac{n}{2} \log(2\pi) + \log {\det A} + \frac{n}{2} \end{aligned}$$
  • kl_divergence:
$$\begin{aligned} \mathcal{D}_{KL}(f_1|| f_2) &= \int_x f_1(x)\log\frac{f_1(x)}{f_2(x)} dx \\\ & = \int_{x \in \mathbb{R}^n} f_1(x)\left\{\left[ -\frac{n}{2} \log(2\pi) - \log(\det A_1) - \frac{1}{2}(x-\mu_1)^{\intercal} \Sigma_1^{-1} (x - \mu_1) \right] + \left[ \frac{n}{2} \log(2\pi) + \log(\det A_2) + \frac{1}{2}(x-\mu_2)^{\intercal} \Sigma_2^{-1} (x - \mu_2)\right]\right\} dx \\\ & = \log(\det A_2) - \log(\det A_1) +\frac{1}{2}\mathbb{E}_1[(X-\mu_2)^{\intercal} \Sigma_2^{-1} (X - \mu_2)] -\frac{n}{2} \\\ & = \log(\det A_2) - \log(\det A_1) +\frac{1}{2}tr [\Sigma_2^{-1}\mathbb{E}_1[ (X - \mu_2) (X-\mu_2)^{\intercal} ]] -\frac{n}{2} \\\ & = \log(\det A_2) - \log(\det A_1) -\frac{n}{2} +\frac{1}{2}tr [\Sigma_2^{-1}\mathbb{E}_1[ XX^{\intercal} -X \mu_2^{\intercal} - \mu_2 X^{\intercal} + \mu_2\mu_2^{\intercal}]] \\\ & = \log(\det A_2) - \log(\det A_1) -\frac{n}{2} +\frac{1}{2}tr [\Sigma_2^{-1} [ Var_1(X) + \mathbb{E}_1(X)\mathbb{E}_1(X)^{\intercal} -\mu_1\mu_2^{\intercal} - \mu_2 \mu_1^{\intercal} + \mu_2\mu_2^{\intercal}]] \\\ & = \log(\det A_2) - \log(\det A_1) -\frac{n}{2} +\frac{1}{2}tr [\Sigma_2^{-1} [ \Sigma_1 + \mu_1\mu_1^{\intercal} -\mu_1\mu_2^{\intercal} - \mu_2 \mu_1^{\intercal} + \mu_2\mu_2^{\intercal}]] \\\ & = \log(\det A_2) - \log(\det A_1) -\frac{n}{2} +\frac{1}{2}\{tr [\Sigma_2^{-1} \Sigma_1] + tr[\Sigma_2^{-1}(\mu_1 - \mu_2)(\mu_1 - \mu_2)^{\intercal}]\} \\\ & = \log(\det A_2) - \log(\det A_1) -\frac{n}{2} +\frac{1}{2}[tr [\Sigma_2^{-1} \Sigma_1] + (\mu_1 - \mu_2)^{\intercal} \Sigma_2^{-1} (\mu_1 - \mu_2)] \end{aligned}$$
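These closed forms can likewise be checked numerically. A numpy/scipy sketch follows; here $A$ is taken to be the Cholesky factor of $\Sigma$, so $\log\det A$ is the sum of the log-diagonal of the factor.

```python
import numpy as np
from scipy.stats import multivariate_normal

def mvn_entropy(cov):
    # H = n/2 * log(2*pi) + log det A + n/2, with A the Cholesky factor of cov.
    n = cov.shape[0]
    A = np.linalg.cholesky(cov)
    return 0.5 * n * np.log(2.0 * np.pi) + np.log(np.diag(A)).sum() + 0.5 * n

def mvn_kl(mu1, cov1, mu2, cov2):
    # KL = log det A2 - log det A1 - n/2
    #      + 1/2 * [tr(cov2^-1 cov1) + (mu1-mu2)^T cov2^-1 (mu1-mu2)]
    n = mu1.shape[0]
    logdet1 = np.log(np.diag(np.linalg.cholesky(cov1))).sum()
    logdet2 = np.log(np.diag(np.linalg.cholesky(cov2))).sum()
    cov2_inv = np.linalg.inv(cov2)
    diff = mu1 - mu2
    return (logdet2 - logdet1 - 0.5 * n
            + 0.5 * (np.trace(cov2_inv @ cov1) + diff @ cov2_inv @ diff))

mu1 = np.array([0.0, 1.0]); cov1 = np.array([[2.0, 0.5], [0.5, 1.0]])
mu2 = np.array([1.0, -1.0]); cov2 = np.array([[1.0, 0.3], [0.3, 2.0]])

# The closed-form entropy matches scipy's reference implementation,
# and the KL of a distribution with itself is zero.
print(abs(mvn_entropy(cov1) - multivariate_normal(mu1, cov1).entropy()))
print(mvn_kl(mu1, cov1, mu1, cov1))
```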

@paddle-ci-bot

paddle-ci-bot bot commented Oct 19, 2023

Sorry to inform you that 064e8a9's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually.


paddle-ci-bot bot commented Nov 1, 2023

Sorry to inform you that 4ce267e's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually.

@luotao1
Contributor

luotao1 commented Nov 17, 2023

The test_distribution_continuous_bernoulli_static unit test did not pass.


paddle-ci-bot bot commented Nov 30, 2023

Sorry to inform you that 8b913d3's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually.

# convert type
if isinstance(probability, (float, int)):
    probability = [probability]
probability = paddle.to_tensor(probability, dtype=self.dtype)
Contributor

If `probability` is already a Tensor, this changes its data type. For example, the user passes `probability` as fp32 while the default dtype is fp64.
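A dtype-preserving conversion along the lines the reviewer asks for might look like this. This is a sketch with numpy standing in for paddle; `convert_probability` and `default_dtype` are illustrative names, not Paddle API.

```python
import numpy as np

def convert_probability(probability, default_dtype=np.float32):
    # Scalars get wrapped and given the default dtype; an array the caller
    # already constructed keeps its own dtype, so an fp32 input is not
    # silently promoted to an fp64 default.
    if isinstance(probability, (float, int)):
        return np.asarray([probability], dtype=default_dtype)
    return np.asarray(probability)  # no dtype argument: dtype is preserved

print(convert_probability(0.3).dtype)                                # float32
print(convert_probability(np.array([0.3], dtype=np.float64)).dtype)  # float64
```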

Contributor Author

Fixed.

cxxly
cxxly previously approved these changes Dec 12, 2023
Contributor

@cxxly cxxly left a comment


LGTM

@jeff41404
Contributor

code is fine, please add a link to the rfc in the description above

@jeff41404
Contributor

the design of ContinuousBernoulli in rfc needs to be consistent with the code

@jeff41404
Contributor

the rfc of MultivariateNormal has the same issue

@NKNaN
Contributor Author

NKNaN commented Dec 13, 2023

code is fine, please add a link to the rfc in the description above

Added the rfc link. The rfc design doc needed some changes; I have submitted the corresponding PR.

@luotao1
Contributor

luotao1 commented Dec 13, 2023

The corresponding Chinese documentation can be submitted as well.

@NKNaN
Contributor Author

NKNaN commented Dec 13, 2023

The corresponding Chinese documentation can be submitted as well.

The Chinese documentation PR has been submitted.

I also revised the corresponding English docs.

Comment on lines 51 to 57
Args:
    probability(int|float|Tensor): The probability of Continuous Bernoulli distribution between [0, 1],
    which characterize the shape of the pdf. If the input data type is int or float, the data type of
    `probability` will be convert to a 1-D Tensor the paddle global default dtype.
    eps(float): Specify the bandwith of the unstable calculation region near 0.5. The unstable calculation region
    would be [0.5 - eps, 0.5 + eps], where the calculation is approximated by talyor expansion. The
    default value is 0.02.
Contributor

Suggested change
Args:
    probability(int|float|Tensor): The probability of Continuous Bernoulli distribution between [0, 1],
    which characterize the shape of the pdf. If the input data type is int or float, the data type of
    `probability` will be convert to a 1-D Tensor the paddle global default dtype.
    eps(float): Specify the bandwith of the unstable calculation region near 0.5. The unstable calculation region
    would be [0.5 - eps, 0.5 + eps], where the calculation is approximated by talyor expansion. The
    default value is 0.02.
Args:
    probability(int|float|Tensor): The probability of Continuous Bernoulli distribution between [0, 1],
        which characterize the shape of the pdf. If the input data type is int or float, the data type of
        `probability` will be convert to a 1-D Tensor the paddle global default dtype.
    eps(float): Specify the bandwith of the unstable calculation region near 0.5. The unstable calculation region
        would be [0.5 - eps, 0.5 + eps], where the calculation is approximated by talyor expansion. The
        default value is 0.02.

For each parameter under Args, continuation lines of the same parameter's description need an extra indent.

Contributor Author

Fixed.

r"""The Continuous Bernoulli distribution with parameter: `probability` characterizing the shape of the density function.
The Continuous Bernoulli distribution is defined on [0, 1], and it can be viewed as a continuous version of the Bernoulli distribution.

[1] Loaiza-Ganem, G., & Cunningham, J. P. The continuous Bernoulli: fixing a pervasive error in variational autoencoders. 2019.
Contributor

Could you link the paper directly? For the citation format, see "How to cross-reference documentation".

Contributor Author

Fixed.

would be [0.5 - eps, 0.5 + eps], where the calculation is approximated by talyor expansion. The
default value is 0.02.

Examples:
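The unstable region this docstring describes can be sketched as a two-branch computation of log C(λ). This is illustrative only, not the PR's implementation; the series coefficients come from artanh(z)/z = 1 + z²/3 + z⁴/5 + ….

```python
import numpy as np

def log_norm_const(lam, eps=0.02):
    # log C(lam) with C(lam) = 2 * artanh(1 - 2*lam) / (1 - 2*lam).
    # At lam = 0.5 the ratio is 0/0, so for lam in [0.5 - eps, 0.5 + eps]
    # (i.e. |z| < 2*eps with z = 1 - 2*lam) use the truncated series
    # artanh(z)/z = 1 + z**2/3 + z**4/5 + ...
    z = 1.0 - 2.0 * lam
    if abs(z) < 2.0 * eps:
        return np.log(2.0) + np.log1p(z * z / 3.0 + z**4 / 5.0)
    return np.log(2.0 * np.arctanh(z) / z)

# Inside the region the two branches agree closely (truncation error is
# roughly z**6/7, about 1e-10 at z = 0.03):
z = 0.03
print(abs((np.log(2.0) + np.log1p(z * z / 3.0 + z**4 / 5.0))
          - np.log(2.0 * np.arctanh(z) / z)))
```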
Contributor Author

Fixed.


In the above equation:

* :math:\Omega: is the support of the distribution.
Contributor

Suggested change
* :math:\Omega: is the support of the distribution.
* :math:`\Omega` is the support of the distribution.

Contributor Author

Fixed.

Contributor

This file has the same issues as continuous_bernoulli.py; I won't repeat them here.

Contributor Author

Fixed in the corresponding places.

@cxxly
Contributor

cxxly commented Dec 14, 2023

Sorry, one more comment: keep the ContinuousBernoulli signature consistent with PyTorch's.

@NKNaN
Contributor Author

NKNaN commented Dec 14, 2023

Sorry, one more comment: keep the ContinuousBernoulli signature consistent with PyTorch's.

Fixed.

sunzhongkai588
sunzhongkai588 previously approved these changes Dec 14, 2023
Contributor

@sunzhongkai588 sunzhongkai588 left a comment


LGTM~

New system message errors were found in the doc-preview CI. @ooooo-create please do a full check for related errors later and fix them.

[0.20103608, 0.07641447])
"""

def __init__(self, probs=None, lims=(0.499, 0.501)):
Contributor

No need for the None default; probs is a required parameter.

Contributor Author

Fixed.

Contributor

@cxxly cxxly left a comment


LGTM

@jeff41404
Contributor

There are spelling and capitalization issues in the link to the rfc; e.g. it should be 20230927_api_design_for_ContinuousBernoulli.md instead of 20230927_api_design_for_continuous_bernoulli.md, causing a 404 error

@NKNaN
Contributor Author

NKNaN commented Dec 18, 2023

There are spelling and capitalization issues in the link to the rfc; e.g. it should be 20230927_api_design_for_ContinuousBernoulli.md instead of 20230927_api_design_for_continuous_bernoulli.md, causing a 404 error

Links have been updated

Contributor

@jeff41404 jeff41404 left a comment


LGTM

@luotao1 luotao1 merged commit 1208eab into PaddlePaddle:develop Dec 18, 2023
28 of 29 checks passed
HermitSun pushed a commit to HermitSun/Paddle that referenced this pull request Dec 21, 2023
…PI (PaddlePaddle#58004)

* add api and test

* add kl-div registration for cb and mvn

* fix docs and test

* fix test

* fix test

* fix mvn test coverage

* fix docs

* update docs

* update cb and mvn

* fix mvn test

* fix test

* fix test

* fix test

* fix test

* fix unstable region calculation

* fix test

* update dtype conversion and tests

* fix test

* fix test

* fix test

* refine docs

* update docs

* update docs

* update docs

* update cb api

* increase cb static test timeout

* fix test time

* fix test

* update cb
@NKNaN NKNaN deleted the ayase/develop2 branch February 2, 2024 04:08