Broadcasting for linalg functions that accept an axis #617

Closed
asmeurer opened this issue Apr 4, 2023 · 3 comments · Fixed by #740
Labels
Narrative Content Narrative documentation content. topic: Linear Algebra Linear algebra.


@asmeurer
Member

asmeurer commented Apr 4, 2023

Two linear algebra functions, cross and vecdot, compute over an arbitrary axis (for vecdot, that axis is contracted).

These APIs currently specify:

x2: Must be compatible with x1 for all non-compute axes (see Broadcasting).

as well as

The compute axis (dimension) must not be broadcasted.

This is ambiguous however when the contracted axis is a dimension that is created by broadcasting, for instance

vecdot(zeros((1, 4, 5)), zeros((4, 5)), axis=0)

Here, axis=0 refers to a dimension of size 1 that exists in x1 but exists in x2 only because broadcasting adds it.

I think this case should be disallowed.
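One way an implementation could detect this case is sketched below. It assumes axis is interpreted against the broadcast shape, as the current text says, and only flags dimensions that one input lacks entirely (not size-1 dimensions that get stretched). The helper name axis_created_by_broadcasting is hypothetical and not part of any library.

```python
import numpy as np

def axis_created_by_broadcasting(shape1, shape2, axis):
    """Hypothetical helper: True if `axis`, interpreted against the broadcast
    shape, points at a dimension that one input does not have at all."""
    out_shape = np.broadcast_shapes(shape1, shape2)
    n = len(out_shape)
    if not -n <= axis < n:
        raise ValueError(f"axis {axis} out of range for rank {n}")
    axis = axis % n  # normalize a negative axis to a nonnegative index
    # Broadcasting prepends dimensions, so an input of rank k occupies the
    # last k positions of the broadcast shape.
    return axis < n - len(shape1) or axis < n - len(shape2)

# The example from the issue: axis=0 exists in x1 but is added to x2.
print(axis_created_by_broadcasting((1, 4, 5), (4, 5), axis=0))   # True
print(axis_created_by_broadcasting((1, 4, 5), (4, 5), axis=-2))  # False
```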

Additionally, axis is ambiguous. It isn't clear if it should refer to the axis before or after broadcasting:

axis (int) – the axis (dimension) of x1 and x2 containing the vectors for which to compute the cross product. Must be an integer on the interval [-N, N), where N is the rank (number of dimensions) of the shape determined according to Broadcasting. If specified as a negative integer, the function must determine the axis along which to compute the cross product by counting backward from the last dimension (where -1 refers to the last dimension). By default, the function must compute the cross product over the last axis. Default: -1.

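This is of particular interest when axis >= 0: broadcasting prepends dimensions, so a nonnegative index can point at different dimensions before and after broadcasting, whereas a negative index counts from the end and picks the same dimension in both readings as long as it is within the rank of both arrays.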
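To make the ambiguity concrete, here is a small NumPy sketch in which the two readings of axis=0 give different answers for the same inputs, while axis=-1 does not. The arrays are made-up illustration values, and the contraction is written out with np.sum rather than calling any vecdot function.

```python
import numpy as np

x1 = np.arange(9.0).reshape(3, 3)
x2 = np.array([1.0, 2.0, 3.0])

# Reading 1: axis=0 refers to each input's own axis 0 (before broadcasting).
# x2's only axis is lined up with x1's axis 0, and the remaining dims broadcast.
before = np.sum(x1 * x2[:, np.newaxis], axis=0)

# Reading 2: axis=0 refers to the broadcast shape (3, 3). x2 is broadcast
# along the *last* axis, so its values vary along axis 1, not axis 0.
after = np.sum(x1 * x2, axis=0)

print(before)  # [24. 30. 36.]
print(after)   # [ 9. 24. 45.]

# With axis=-1 both readings select the same dimension and agree.
print(np.sum(x1 * x2, axis=-1))  # [ 8. 26. 44.]
```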

NumPy appears to refer to the axis before broadcasting:

>>> np.cross(np.zeros((3, 1, 2)), np.zeros((3,)), axis=0)
array([[[0., 0.]],

       [[0., 0.]],

       [[0., 0.]]])

In fact, these two arrays aren't strictly broadcast compatible (their trailing dimensions, 2 and 3, differ). What NumPy does is move the axis dimension of x1 and x2 to the end of each array, then broadcast x1[..., 0] and x2[..., 0]. Effectively:

import numpy as np

x1 = np.moveaxis(x1, axis, -1)
x2 = np.moveaxis(x2, axis, -1)
if x1.shape[-1] != 3 or x2.shape[-1] != 3:
    raise ValueError("incompatible dimensions for cross product")

shape = np.broadcast(x1[..., 0], x2[..., 0]).shape

In other words, the arrays should be broadcast compatible after removing axis from the shape (and we should have x1.shape[axis] == x2.shape[axis] == 3).
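The moveaxis-then-broadcast behavior described above can be written out as a self-contained NumPy function. This is a sketch that follows the steps of the snippet, not NumPy's own source; cross_moveaxis is a hypothetical name, and the comparison against np.cross only covers 3-element vectors.

```python
import numpy as np

def cross_moveaxis(x1, x2, axis=-1):
    """Hypothetical cross product using np.cross-style axis semantics.

    `axis` indexes each input's own shape; that dimension must have size 3
    in both inputs, and only the remaining dimensions are broadcast.
    """
    x1 = np.moveaxis(np.asarray(x1), axis, -1)
    x2 = np.moveaxis(np.asarray(x2), axis, -1)
    if x1.shape[-1] != 3 or x2.shape[-1] != 3:
        raise ValueError("incompatible dimensions for cross product")
    a0, a1, a2 = x1[..., 0], x1[..., 1], x1[..., 2]
    b0, b1, b2 = x2[..., 0], x2[..., 1], x2[..., 2]
    # Each component broadcasts x1[..., i] against x2[..., j], i.e. the
    # two shapes with the compute axis removed.
    out = np.stack([a1 * b2 - a2 * b1,
                    a2 * b0 - a0 * b2,
                    a0 * b1 - a1 * b0], axis=-1)
    # Put the 3-vector dimension back where `axis` pointed, as np.cross does.
    return np.moveaxis(out, -1, axis)

a = np.arange(6.0).reshape(3, 1, 2)
b = np.array([1.0, 2.0, 3.0])
print(cross_moveaxis(a, b, axis=0).shape)                                 # (3, 1, 2)
print(np.allclose(cross_moveaxis(a, b, axis=0), np.cross(a, b, axis=0)))  # True
```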

NumPy doesn't have vecdot yet, but it should obviously work the same (the only difference being the contracted axis can have any size in vecdot, not just 3, and unlike cross, in vecdot the contracted axis is removed from the resulting shape). torch.linalg.cross doesn't appear to support any broadcasting.
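Applying the same rule to vecdot gives something like the sketch below. vecdot_moveaxis is a hypothetical name; it assumes the contracted dimensions must match exactly (so the compute axis is never broadcast) and conjugates x1, as vecdot does for complex inputs. With these semantics, the example from the top of the issue is rejected.

```python
import numpy as np

def vecdot_moveaxis(x1, x2, axis=-1):
    """Hypothetical vecdot using np.cross-style axis semantics.

    `axis` indexes each input's own shape. The contracted dimensions must
    have the same size; the remaining dimensions are broadcast, and the
    contracted dimension is dropped from the result.
    """
    x1 = np.moveaxis(np.asarray(x1), axis, -1)
    x2 = np.moveaxis(np.asarray(x2), axis, -1)
    if x1.shape[-1] != x2.shape[-1]:
        raise ValueError("contracted dimensions must have the same size")
    # Multiplying broadcasts the non-contracted dims (raising if they are
    # incompatible); summing over the last axis removes the contracted dim.
    return np.sum(np.conj(x1) * x2, axis=-1)

print(vecdot_moveaxis(np.ones((3, 1, 2)), np.ones(3), axis=0).shape)  # (1, 2)

try:
    # x1's axis 0 has size 1 while x2's axis 0 has size 4.
    vecdot_moveaxis(np.zeros((1, 4, 5)), np.zeros((4, 5)), axis=0)
except ValueError as e:
    print("rejected:", e)
```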

My implementations of vecdot in numpy.array_api and array-api-compat have been using the idea that axis refers to the axis after broadcasting, and they allow an added broadcasted axis. But I think this should be changed to work like np.cross. The numpy.array_api and array_api_compat.numpy cross implementations just reuse np.cross and therefore use those semantics (I didn't realize until now that we weren't actually testing any broadcasting rules for cross in the test suite).

This was discussed at data-apis/array-api-compat#35 (comment) (CC @lezcano).

Finally, note that tensordot doesn't have this issue because the axes are specified for each array separately.

@kgryte kgryte added the Narrative Content and topic: Linear Algebra labels Jun 29, 2023
@kgryte kgryte added this to the v2023 milestone Jun 29, 2023
@kgryte
Contributor

kgryte commented Nov 6, 2023

Additionally, axis is ambiguous. It isn't clear if it should refer to the axis before or after broadcasting.

I don't believe this is correct. The spec, as you quote above, says "where N is the rank (number of dimensions) of the shape determined according to Broadcasting". Meaning, the broadcasted shape, not the pre-broadcasted shape.

@kgryte
Contributor

kgryte commented Nov 6, 2023

@asmeurer I'm having some difficulty following your OP. Can you suggest specifically how you'd like to see the specification revised? If it would be easier, you can submit a PR with the proposed updated guidance.

asmeurer added a commit to asmeurer/array-api that referenced this issue Feb 2, 2024
Nonnegative axes, and negative axes whose magnitude exceeds the number of
dimensions of the smaller of the two arrays, are unspecified.

This is because it is ambiguous in these cases whether the
dimension should refer to the axis before or after broadcasting. Previously,
the spec stated it should refer to the dimension before broadcasting, but this
deviates from NumPy gufunc behavior, and results in ambiguous and confusing
situations where, for instance, the result of the function is different
when the inputs are manually broadcasted together.

Also clean up some of the cross text a little bit since the computed dimension
must be exactly size 3.

Fixes data-apis#724
Fixes data-apis#617

See the discussion in those issues for more details.
@asmeurer
Member Author

asmeurer commented Feb 2, 2024

Fix at #740.

kgryte pushed a commit that referenced this issue Feb 13, 2024
This commit updates specification guidance in `vecdot` and `cross` to no longer explicitly support positive `axis` kwarg values. Previous specification guidance conflicts with NumPy gufuncs and restricting to negative integers removes ambiguity in determining over which axis to perform computation. This commit uses `should`, not `must`, to allow conforming libraries to support nonnegative `axis` values for backward compatibility.

Closes: #724
Closes: #617
PR-URL: 	#740
Reviewed-by: Athan Reines <kgryte@gmail.com>
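Why restricting axis to negative integers removes the ambiguity can be checked numerically. In the sketch below, check_negative_axis is a hypothetical validator, and its bound [-min(x1.ndim, x2.ndim), -1] is an assumption taken from the earlier commit message's "smaller of the two arrays" wording. For a negative axis in that range, moving each input's own axis to the end and broadcasting the rest gives the same result as broadcasting first and contracting that axis of the result.

```python
import numpy as np

def check_negative_axis(ndim1, ndim2, axis):
    """Hypothetical validator: accepts only axis in [-min(ndim1, ndim2), -1]."""
    k = min(ndim1, ndim2)
    if not -k <= axis <= -1:
        raise ValueError(f"axis must be in [-{k}, -1], got {axis}")
    return axis

rng = np.random.default_rng(0)
x1 = rng.standard_normal((5, 4, 3))
x2 = rng.standard_normal((4, 3))
axis = check_negative_axis(x1.ndim, x2.ndim, -2)

# Reading 1: move each input's own axis to the end, broadcast the rest
# (real-valued inputs, so no conjugation is needed).
moved = np.sum(np.moveaxis(x1, axis, -1) * np.moveaxis(x2, axis, -1), axis=-1)
# Reading 2: broadcast first, then contract the same axis of the result.
b1, b2 = np.broadcast_arrays(x1, x2)
broadcast_first = np.sum(b1 * b2, axis=axis)

print(np.allclose(moved, broadcast_first))  # True
```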