diff --git a/cubed/array_api/linear_algebra_functions.py b/cubed/array_api/linear_algebra_functions.py index 538fb2a0..272e06b4 100644 --- a/cubed/array_api/linear_algebra_functions.py +++ b/cubed/array_api/linear_algebra_functions.py @@ -3,7 +3,11 @@ from cubed.array_api.data_type_functions import result_type from cubed.array_api.dtypes import _numeric_dtypes -from cubed.array_api.manipulation_functions import expand_dims +from cubed.array_api.manipulation_functions import ( + broadcast_arrays, + expand_dims, + moveaxis, +) from cubed.backend_array_api import namespace as nxp from cubed.core import blockwise, reduction, squeeze @@ -158,12 +162,21 @@ def _tensordot(a, b, axes): def vecdot(x1, x2, /, *, axis=-1, use_new_impl=True, split_every=None): + # based on the implementation in array-api-compat if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: raise TypeError("Only numeric dtypes are allowed in vecdot") - return tensordot( - x1, - x2, - axes=((axis,), (axis,)), + + if x1.shape[axis] != x2.shape[axis]: + raise ValueError("x1 and x2 must have the same size along the given axis") + + x1_ = moveaxis(x1, axis, -1) + x2_ = moveaxis(x2, axis, -1) + x1_, x2_ = broadcast_arrays(x1_, x2_) + + res = matmul( + x1_[..., None, :], + x2_[..., None], use_new_impl=use_new_impl, split_every=split_every, ) + return res[..., 0, 0]