-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Zero-Dim] support input 0D Tensor for distribution transform api #47677
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,9 +13,7 @@ | |
# limitations under the License. | ||
|
||
import enum | ||
import functools | ||
import math | ||
import operator | ||
import typing | ||
|
||
import paddle | ||
|
@@ -401,7 +399,7 @@ def _inverse(self, y): | |
return -y, y | ||
|
||
def _inverse_log_det_jacobian(self, y): | ||
zero = paddle.zeros([1], dtype=y.dtype) | ||
zero = paddle.zeros([], dtype=y.dtype) | ||
return zero, zero | ||
|
||
@property | ||
|
@@ -872,12 +870,16 @@ def __init__(self, in_event_shape, out_event_shape): | |
f"Squence[int], but got 'in_event_shape': {in_event_shape}, " | ||
f"'out_event_shape': {out_event_shape}" | ||
) | ||
if functools.reduce(operator.mul, in_event_shape) != functools.reduce( | ||
operator.mul, out_event_shape | ||
): | ||
in_size = 1 | ||
for e in in_event_shape: | ||
in_size *= e | ||
out_size = 1 | ||
for e in out_event_shape: | ||
out_size *= e | ||
if in_size != out_size: | ||
raise ValueError( | ||
f"The numel of 'in_event_shape' should be 'out_event_shape', " | ||
f"but got {functools.reduce(operator.mul, in_event_shape)}!={functools.reduce(operator.mul, out_event_shape)}" | ||
f"but got {in_size}!={out_size}" | ||
) | ||
|
||
self._in_event_shape = tuple(in_event_shape) | ||
|
@@ -917,7 +919,9 @@ def _forward_shape(self, shape): | |
raise ValueError( | ||
f"Expected length of 'shape' is not less than {len(self._in_event_shape)}, but got {len(shape)}" | ||
) | ||
if shape[-len(self._in_event_shape) :] != self._in_event_shape: | ||
if tuple(shape[-len(self._in_event_shape) :]) != tuple( | ||
self._in_event_shape | ||
): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 建议只读类型用tuple There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done |
||
raise ValueError( | ||
f"Event shape mismatch, expected: {self._in_event_shape}, but got {shape[-len(self._in_event_shape):]}" | ||
) | ||
|
@@ -930,7 +934,9 @@ def _inverse_shape(self, shape): | |
raise ValueError( | ||
f"Expected 'shape' length is not less than {len(self._out_event_shape)}, but got {len(shape)}" | ||
) | ||
if shape[-len(self._out_event_shape) :] != self._out_event_shape: | ||
if tuple(shape[-len(self._out_event_shape) :]) != tuple( | ||
self._out_event_shape | ||
): | ||
raise ValueError( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 同上 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done |
||
f"Event shape mismatch, expected: {self._out_event_shape}, but got {shape[-len(self._out_event_shape):]}" | ||
) | ||
|
@@ -939,7 +945,7 @@ def _inverse_shape(self, shape): | |
) | ||
|
||
def _forward_log_det_jacobian(self, x): | ||
# paddle.zeros not support zero dimension Tensor. | ||
# TODO(zhouwei): should not set shape to [1], which is [] | ||
shape = x.shape[: x.dim() - len(self._in_event_shape)] or [1] | ||
return paddle.zeros(shape, dtype=x.dtype) | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -103,7 +103,7 @@ def wrapper(f, instance=None): | |
frame_locals[name].__doc__ = doc_func(f, num, p) | ||
|
||
# Delete original patches to prevent new function from evaluating | ||
# original patching object as well as re-constructed patches. | ||
# original patching object as well as re-constrfucted patches. | ||
delete_patches_if_need(f) | ||
|
||
f.__test__ = False | ||
Comment on lines
+106
to
109
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 貌似有 typo |
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个修改的目的是?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
0D Tensor的shape为[],numel为1,用functools.reduce无法实现,会报空列表的错误,所以就自行实现了