Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
improve scatter_nd doc according to tf version
Browse files Browse the repository at this point in the history
  • Loading branch information
Hao Jin committed Jul 27, 2018
1 parent c13ce5c commit 0683d0a
Showing 1 changed file with 23 additions and 10 deletions.
33 changes: 23 additions & 10 deletions src/operator/tensor/indexing_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -573,20 +573,22 @@ Examples::
.add_argument("indices", "NDArray-or-Symbol", "indices");

NNVM_REGISTER_OP(scatter_nd)
.describe(R"code(Scatters data into a new tensor according to indices.
.describe(R"code(Scatters `data` into a new tensor according to `indices`.
Given `data` with shape `(Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1})` and indices with shape
`(M, Y_0, ..., Y_{K-1})`, the output will have shape `(X_0, X_1, ..., X_{N-1})`,
where `M <= N`. If `M == N`, data shape should simply be `(Y_0, ..., Y_{K-1})`.
`indices` is an integer tensor containing indices into a new tensor of shape `shape`.
The last dimension of `indices` can be at most the rank of `shape`:
The elements in output is defined as follows::
.. math::
output[indices[0, y_0, ..., y_{K-1}],
...,
indices[M-1, y_0, ..., y_{K-1}],
x_M, ..., x_{N-1}] = data[y_0, ..., y_{K-1}, x_M, ..., x_{N-1}]
indices.shape[-1] <= shape.rank
The last dimension of `indices` corresponds to indices into elements (if `indices.shape[-1] =
shape.rank`) or slices (if `indices.shape[-1] < shape.rank`) along dimension `indices.shape[-1]`
of `shape`. `data` is a tensor with shape:
all other entries in output are 0.
.. math::
indices.shape[-1] + shape[indices.shape[-1]:]
.. warning::
Expand All @@ -601,6 +603,17 @@ Examples::
shape = (2, 2)
scatter_nd(data, indices, shape) = [[0, 0], [2, 3]]
data = [[[5, 5, 5, 5], [6, 6, 6, 6],
[7, 7, 7, 7], [8, 8, 8, 8]],
[[5, 5, 5, 5], [6, 6, 6, 6],
[7, 7, 7, 7], [8, 8, 8, 8]]]
indices = [[0], [2]]
shape = (4, 4, 4)
scatter_nd(data, indices, shape) = [[[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]],
[[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]],
[[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]],
[[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]]
)code")
.set_num_outputs(1)
.set_num_inputs(2)
Expand Down

0 comments on commit 0683d0a

Please sign in to comment.