-
Notifications
You must be signed in to change notification settings - Fork 137
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
update HMC API #541
update HMC API #541
Conversation
ppsci/probability/hmc.py
Outdated
from paddle.distribution import Normal | ||
from paddle.distribution import Uniform | ||
|
||
from ppsci.utils import set_random_seed |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
==> from ppsci import utils
ppsci/probability/hmc.py
Outdated
from paddle.distribution import Normal | ||
from paddle.distribution import Uniform |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
==> from paddle import distribution
ppsci/probability/hmc.py
Outdated
@@ -12,95 +12,149 @@ | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
|
|||
from collections import OrderedDict |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
python>=3.7 的dict应该是默认与添加顺序保持一致的,可以不使用这个OrderedDict
ppsci/probability/hmc.py
Outdated
""" | ||
|
||
def __init__(self, tensor: paddle.Tensor): | ||
self.tensor = tensor | ||
def __init__(self, tensor_dict: OrderedDict): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
==> Dict[str, paddle.Tensor]
num_warmup_steps: int = 0, | ||
random_seed: int = 1024, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
docstring添加一下这两个参数
ppsci/probability/hmc.py
Outdated
num_warmup_steps: int = 0, | ||
random_seed: int = 1024, | ||
): | ||
self.dist = distribution_fn |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
==> self.distribution_fn
ppsci/probability/hmc.py
Outdated
q0_nlp = -self.dist(**q0) | ||
q1_nlp = -self.dist(**q1) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
后缀_nlp表示什么意思呢
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个是之前pr的朋友写的,我没改,可能是negative log probability吧。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个是之前pr的朋友写的,我没改,可能是negative log probability吧。
这个命名感觉不太好,可以在上方加一个带有全称的注释说明nlp的意义
ppsci/probability/hmc.py
Outdated
|
||
acceptance = paddle.minimum( | ||
paddle.to_tensor(1.0), paddle.exp((q0_nlp + p0_nlp) - (p1_nlp + q1_nlp)) | ||
) | ||
|
||
# whether accept the proposed state position | ||
event = paddle.uniform(shape=[], min=0, max=1) | ||
event = self._rv_unif.sample([1]).squeeze() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
paddle已经支持0-D tensor,可以试试直接使用[]
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
完善一下每个函数的type hint,包括入参和返回类型
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
hi, @NKNaN
|
Describe
Future Work
如果需要达到和Pyro等一样的速度需要在采样时加入调整步长的功能。