Commit

Template {Broadcast,Unbroadcast}Array on element type/tag
Previously, these functions were entirely type erased and only supported
Shared data pointers.

PiperOrigin-RevId: 658149134
Change-Id: I07180f1c076233dc87323584ee0c68a674ebf567
jbms authored and copybara-github committed Jul 31, 2024
1 parent 7cb76bb commit bc6d6a1
Showing 6 changed files with 207 additions and 150 deletions.
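
To see what the retemplating means for callers, here is a minimal sketch (not part of the commit; the signatures follow the updated array.h and tests below, and the includes are assumed):

#include "tensorstore/array.h"
#include "tensorstore/index.h"
#include "tensorstore/util/result.h"
#include "tensorstore/util/span.h"

// Before: BroadcastArray took SharedArrayView<const void> and returned
// Result<SharedArray<const void>>, erasing the element type.
// After: the element tag is a template parameter, so the result stays typed.
tensorstore::Result<tensorstore::SharedArray<int>> BroadcastExample() {
  auto source = tensorstore::MakeArray<int>({1, 2, 3});  // shape {3}
  // Broadcasts shape {3} -> {2, 3}; the int element type is preserved.
  return tensorstore::BroadcastArray(
      source, tensorstore::span<const tensorstore::Index>({2, 3}));
}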
3 changes: 3 additions & 0 deletions tensorstore/BUILD
@@ -789,8 +789,11 @@ tensorstore_cc_library(
"//tensorstore/internal:type_traits",
"//tensorstore/util:constant_vector",
"//tensorstore/util:extents",
"//tensorstore/util:result",
"//tensorstore/util:span",
"//tensorstore/util:status",
"//tensorstore/util:str_cat",
"@com_google_absl//absl/status",
],
)

95 changes: 12 additions & 83 deletions tensorstore/array.cc
@@ -37,15 +37,14 @@
#include "tensorstore/internal/unaligned_data_type_functions.h"
#include "tensorstore/rank.h"
#include "tensorstore/serialization/serialization.h"
-#include "tensorstore/serialization/span.h"
+#include "tensorstore/serialization/span.h"  // IWYU pragma: keep
#include "tensorstore/strided_layout.h"
#include "tensorstore/util/dimension_set.h"
#include "tensorstore/util/element_pointer.h"
#include "tensorstore/util/internal/iterate_impl.h"
#include "tensorstore/util/iterate.h"
#include "tensorstore/util/result.h"
#include "tensorstore/util/span.h"
#include "tensorstore/util/status.h"
#include "tensorstore/util/str_cat.h"

namespace tensorstore {
@@ -241,71 +240,6 @@ void PrintToOstream(
}
}  // namespace internal_array

-absl::Status ValidateShapeBroadcast(span<const Index> source_shape,
-                                    span<const Index> target_shape) {
-  for (DimensionIndex source_dim = 0; source_dim < source_shape.size();
-       ++source_dim) {
-    const Index source_size = source_shape[source_dim];
-    if (source_size == 1) continue;
-    const DimensionIndex target_dim =
-        target_shape.size() - source_shape.size() + source_dim;
-    if (target_dim < 0 || target_shape[target_dim] != source_size) {
-      return absl::InvalidArgumentError(
-          tensorstore::StrCat("Cannot broadcast array of shape ", source_shape,
-                              " to target shape ", target_shape));
-    }
-  }
-  return absl::OkStatus();
-}
-
-absl::Status BroadcastStridedLayout(StridedLayoutView<> source,
-                                    span<const Index> target_shape,
-                                    Index* target_byte_strides) {
-  TENSORSTORE_RETURN_IF_ERROR(
-      ValidateShapeBroadcast(source.shape(), target_shape));
-  SharedArray<const void> target;
-  for (DimensionIndex target_dim = 0; target_dim < target_shape.size();
-       ++target_dim) {
-    const DimensionIndex source_dim =
-        target_dim + source.rank() - target_shape.size();
-    target_byte_strides[target_dim] =
-        (source_dim < 0 || source.shape()[source_dim] == 1)
-            ? 0
-            : source.byte_strides()[source_dim];
-  }
-  return absl::OkStatus();
-}
-
-Result<SharedArray<const void>> BroadcastArray(
-    SharedArrayView<const void> source, span<const Index> target_shape) {
-  SharedArray<const void> target;
-  target.layout().set_rank(target_shape.size());
-  TENSORSTORE_RETURN_IF_ERROR(BroadcastStridedLayout(
-      source.layout(), target_shape, target.byte_strides().data()));
-  target.element_pointer() = std::move(source.element_pointer());
-  std::copy(target_shape.begin(), target_shape.end(), target.shape().begin());
-  return target;
-}
-
-Result<SharedOffsetArray<const void>> BroadcastArray(
-    SharedOffsetArrayView<const void> source, BoxView<> target_domain) {
-  SharedOffsetArray<const void> target;
-  target.layout().set_rank(target_domain.rank());
-  TENSORSTORE_RETURN_IF_ERROR(BroadcastStridedLayout(
-      StridedLayoutView<>(source.shape(), source.byte_strides()),
-      target_domain.shape(), target.byte_strides().data()));
-  std::copy_n(target_domain.origin().begin(), target_domain.rank(),
-              target.origin().begin());
-  std::copy_n(target_domain.shape().begin(), target_domain.rank(),
-              target.shape().begin());
-  target.element_pointer() =
-      AddByteOffset(std::move(source.element_pointer()),
-                    internal::wrap_on_overflow::Subtract(
-                        source.layout().origin_byte_offset(),
-                        target.layout().origin_byte_offset()));
-  return target;
-}
-
namespace internal_array {
void UnbroadcastStridedLayout(StridedLayoutView<> layout,
                              span<Index> unbroadcast_shape,
@@ -324,30 +258,25 @@ void UnbroadcastStridedLayout(StridedLayoutView<> layout,
    unbroadcast_byte_strides[i] = byte_stride;
  }
}
-}  // namespace internal_array

-SharedArray<const void> UnbroadcastArray(
-    SharedOffsetArrayView<const void> source) {
+void UnbroadcastStridedLayout(StridedLayoutView<> layout,
+                              StridedLayout<>& unbroadcast_layout) {
  DimensionIndex new_rank = 0;
-  for (DimensionIndex orig_dim = source.rank() - 1; orig_dim >= 0; --orig_dim) {
-    if (source.shape()[orig_dim] != 1 && source.byte_strides()[orig_dim] != 0) {
-      new_rank = source.rank() - orig_dim;
+  for (DimensionIndex orig_dim = layout.rank() - 1; orig_dim >= 0; --orig_dim) {
+    if (layout.shape()[orig_dim] != 1 && layout.byte_strides()[orig_dim] != 0) {
+      new_rank = layout.rank() - orig_dim;
    }
  }

-  SharedArray<const void> new_array;
-  new_array.layout().set_rank(new_rank);
+  unbroadcast_layout.set_rank(new_rank);
  internal_array::UnbroadcastStridedLayout(
      StridedLayoutView<>(
-          new_rank, source.shape().data() + source.rank() - new_rank,
-          source.byte_strides().data() + source.rank() - new_rank),
-      new_array.shape(), new_array.byte_strides());
-  new_array.element_pointer() =
-      AddByteOffset(std::move(source.element_pointer()),
-                    source.layout().origin_byte_offset());
-  return new_array;
+          new_rank, layout.shape().data() + layout.rank() - new_rank,
+          layout.byte_strides().data() + layout.rank() - new_rank),
+      unbroadcast_layout.shape(), unbroadcast_layout.byte_strides());
}

+}  // namespace internal_array
+
namespace internal_array {

bool EncodeArray(serialization::EncodeSink& sink,
107 changes: 46 additions & 61 deletions tensorstore/array.h
@@ -1927,59 +1927,6 @@ bool AreArraysIdenticallyEqual(
  return AreArraysEqual(a, b, EqualityComparisonKind::identical);
}

-/// Validates that `source_shape` can be broadcast to `target_shape`.
-///
-/// A `source_shape` can be broadcast to a `target_shape` if, starting from the
-/// trailing (highest index) dimensions, the size in `source_shape` is either
-/// `1` or equal to the size in `target_shape`. Any additional leading
-/// dimensions of `source_shape` that don't correspond to a dimension of
-/// `target_shape` must be `1`. There are no restrictions on additional leading
-/// dimensions of `target_shape` that don't correspond to a dimension of
-/// `source_shape`.
-///
-/// For example:
-///
-///     [VALID]
-///     source_shape: 5
-///     target_shape: 4, 5
-///
-///     [VALID]
-///     source_shape: 4, 1
-///     target_shape: 4, 5
-///
-///     [VALID]
-///     source_shape: 1, 1, 5
-///     target_shape: 4, 5
-///
-///     [INVALID]
-///     source_shape: 2, 5
-///     target_shape: 4, 5
-///
-///     [INVALID]
-///     source_shape: 2, 5
-///     target_shape: 5
-///
-/// \returns `absl::OkStatus()` if the shapes are compatible.
-/// \error `absl::StatusCode::kInvalidArgument` if the shapes are not
-///     compatible.
-/// \relates Array
-/// \membergroup Broadcasting
-absl::Status ValidateShapeBroadcast(span<const Index> source_shape,
-                                    span<const Index> target_shape);
-
-/// Broadcasts `source` to `target_shape`.
-///
-/// \param source Source layout to broadcast.
-/// \param target_shape Target shape to which `source` will be broadcast.
-/// \param target_byte_strides Pointer to array of length `target_shape.size()`.
-/// \error `absl::StatusCode::kInvalidArgument` if the shapes are not
-///     compatible.
-/// \relates Array
-/// \membergroup Broadcasting
-absl::Status BroadcastStridedLayout(StridedLayoutView<> source,
-                                    span<const Index> target_shape,
-                                    Index* target_byte_strides);
-
/// Broadcasts `source` to `target_shape`.
///
/// For example::
Expand All @@ -2002,10 +1949,29 @@ absl::Status BroadcastStridedLayout(StridedLayoutView<> source,
/// compatible.
/// \relates Array
/// \membergroup Broadcasting
Result<SharedArray<const void>> BroadcastArray(
SharedArrayView<const void> source, span<const Index> target_shape);
Result<SharedOffsetArray<const void>> BroadcastArray(
SharedOffsetArrayView<const void> source, BoxView<> target_domain);
template <typename ElementTag, DimensionIndex Rank, ContainerKind CKind>
Result<Array<ElementTag, dynamic_rank, zero_origin>> BroadcastArray(
const Array<ElementTag, Rank, zero_origin, CKind>& source,
tensorstore::span<const Index> target_shape) {
Array<ElementTag> target;
TENSORSTORE_RETURN_IF_ERROR(
BroadcastStridedLayout(source.layout(), target_shape, target.layout()));
target.element_pointer() = source.element_pointer();
return target;
}
template <typename ElementTag, DimensionIndex Rank, ArrayOriginKind OriginKind,
ContainerKind CKind>
Result<Array<ElementTag, dynamic_rank, offset_origin>> BroadcastArray(
const Array<ElementTag, Rank, OriginKind, CKind>& source,
BoxView<> target_domain) {
Array<ElementTag, dynamic_rank, offset_origin> target;
TENSORSTORE_ASSIGN_OR_RETURN(
Index byte_offset,
BroadcastStridedLayout(source.layout(), target_domain, target.layout()));
target.element_pointer() =
AddByteOffset(source.element_pointer(), byte_offset);
return target;
}

namespace internal_array {
/// Converts zero-stride dimensions (with non-zero size) to have an extent of 1,
@@ -2016,9 +1982,13 @@ namespace internal_array {
///     with the unbroadcast shape.
/// \param unbroadcast_byte_strides[out] Array of size `layout.rank()` to be
///     filled with the unbroadcast byte strides.
+/// \param unbroadcast_layout[out] Set to the unbroadcast layout.
+void UnbroadcastStridedLayout(
+    StridedLayoutView<> layout, tensorstore::span<Index> unbroadcast_shape,
+    tensorstore::span<Index> unbroadcast_byte_strides);
void UnbroadcastStridedLayout(StridedLayoutView<> layout,
-                              span<Index> unbroadcast_shape,
-                              span<Index> unbroadcast_byte_strides);
+                              StridedLayout<>& unbroadcast_layout);

}  // namespace internal_array

/// Converts zero-stride dimensions (with non-zero size) to have an extent of 1,
@@ -2033,8 +2003,23 @@ void UnbroadcastStridedLayout(StridedLayoutView<> layout,
///
/// \relates Array
/// \membergroup Broadcasting
-SharedArray<const void> UnbroadcastArray(
-    SharedOffsetArrayView<const void> source);
+template <typename ElementTag, DimensionIndex Rank, ArrayOriginKind OriginKind,
+          ContainerKind CKind>
+Array<ElementTag> UnbroadcastArray(
+    const Array<ElementTag, Rank, OriginKind, CKind>& source) {
+  Array<ElementTag> target;
+  internal_array::UnbroadcastStridedLayout(
+      StridedLayout<>(source.rank(), source.shape().data(),
+                      source.byte_strides().data()),
+      target.layout());
+  if constexpr (OriginKind == offset_origin) {
+    target.element_pointer() = AddByteOffset(
+        source.element_pointer(), source.layout().origin_byte_offset());
+  } else {
+    target.element_pointer() = source.element_pointer();
+  }
+  return target;
+}

/// Converts zero-stride dimensions (with non-zero size) to have an extent of 1,
/// and translates the origin to 0.
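
The UnbroadcastArray template above also keeps the element type through a broadcast round trip. A hedged usage sketch (assumed includes and .value() unwrapping; the behavior mirrors UnbroadcastArrayTest below):

#include <cassert>
#include "tensorstore/array.h"
#include "tensorstore/box.h"

void UnbroadcastExample() {
  auto orig = tensorstore::MakeArray<int>({{1, 2}, {3, 4}});  // shape {2, 2}
  // Broadcast to a 3-d offset domain; the trailing dimensions {2, 2} match.
  auto broadcast = tensorstore::BroadcastArray(
                       orig, tensorstore::BoxView<>({1, 2, 3}, {3, 2, 2}))
                       .value();
  // Now returns SharedArray<int> (typed, zero origin) rather than the old
  // type-erased SharedArray<const void>.
  tensorstore::SharedArray<int> unbroadcast =
      tensorstore::UnbroadcastArray(broadcast);
  assert(unbroadcast == orig);  // stride-0 dims dropped, origin reset to 0
}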
22 changes: 16 additions & 6 deletions tensorstore/array_test.cc
@@ -1835,9 +1835,13 @@ TEST(BroadcastStridedLayoutTest, Basic) {
}

TEST(BroadcastArrayTest, ZeroOrigin) {
-  EXPECT_THAT(
-      BroadcastArray(MakeArray<int>({1, 2, 3}), span<const Index>({2, 3})),
-      MakeArray<int>({{1, 2, 3}, {1, 2, 3}}));
+  {
+    TENSORSTORE_ASSERT_OK_AND_ASSIGN(
+        auto b,
+        BroadcastArray(MakeArray<int>({1, 2, 3}), span<const Index>({2, 3})));
+    static_assert(std::is_same_v<decltype(b), SharedArray<int>>);
+    EXPECT_THAT(b, MakeArray<int>({{1, 2, 3}, {1, 2, 3}}));
+  }
  EXPECT_THAT(BroadcastArray(MakeArray<int>({{1}, {2}, {3}}),
                             span<const Index>({3, 2})),
              MakeArray<int>({{1, 1}, {2, 2}, {3, 3}}));
@@ -1850,9 +1854,14 @@ TEST(BroadcastArrayTest, ZeroOrigin) {
}

TEST(BroadcastArrayTest, OffsetOrigin) {
-  EXPECT_THAT(BroadcastArray(MakeOffsetArray<int>({3}, {1, 2, 3}),
-                             BoxView<>({1, 2}, {2, 3})),
-              MakeOffsetArray<int>({1, 2}, {{1, 2, 3}, {1, 2, 3}}));
+  {
+    TENSORSTORE_ASSERT_OK_AND_ASSIGN(
+        auto b, BroadcastArray(MakeOffsetArray<int>({3}, {1, 2, 3}),
+                               BoxView<>({1, 2}, {2, 3})));
+    static_assert(
+        std::is_same_v<decltype(b), tensorstore::SharedOffsetArray<int>>);
+    EXPECT_THAT(b, MakeOffsetArray<int>({1, 2}, {{1, 2, 3}, {1, 2, 3}}));
+  }
  EXPECT_THAT(BroadcastArray(MakeOffsetArray<int>({3, 4}, {{1}, {2}, {3}}),
                             BoxView<>({1, 2}, {3, 2})),
              MakeOffsetArray<int>({1, 2}, {{1, 1}, {2, 2}, {3, 3}}));
@@ -1870,6 +1879,7 @@ TEST(UnbroadcastArrayTest, Basic) {
      auto broadcast_array,
      BroadcastArray(orig_array, BoxView<>({1, 2, 3, 4}, {2, 3, 2, 2})));
  auto unbroadcast_array = UnbroadcastArray(broadcast_array);
+  static_assert(std::is_same_v<decltype(unbroadcast_array), SharedArray<int>>);
  auto unbroadcast_array2 = UnbroadcastArray(unbroadcast_array);
  EXPECT_EQ(orig_array, unbroadcast_array);
  EXPECT_EQ(orig_array.pointer(), unbroadcast_array.pointer());
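
The shape-compatibility rules documented for ValidateShapeBroadcast (trailing dimensions must match, or be 1 in the source) still govern the templated functions; a mismatch now surfaces through the returned Result. A hedged sketch of the error path, reusing the [INVALID] shapes from the doc comment in the array.h diff above:

#include <cassert>
#include "absl/status/status.h"
#include "tensorstore/array.h"
#include "tensorstore/util/span.h"

void BroadcastErrorExample() {
  // source_shape {2, 5} cannot broadcast to target_shape {4, 5}: source
  // dimension 0 has size 2, which is neither 1 nor equal to 4.
  auto source =
      tensorstore::MakeArray<int>({{1, 2, 3, 4, 5}, {6, 7, 8, 9, 10}});
  auto result = tensorstore::BroadcastArray(
      source, tensorstore::span<const tensorstore::Index>({4, 5}));
  assert(!result.ok());
  assert(absl::IsInvalidArgument(result.status()));
}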