Skip to content

Commit

Permalink
Add torch.argsort
Browse files Browse the repository at this point in the history
  • Loading branch information
davoclavo committed Jul 1, 2023
1 parent 6e69cbd commit c1ff1e3
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 1 deletion.
46 changes: 45 additions & 1 deletion core/src/main/scala/torch/ops/ComparisonOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,50 @@ private[torch] trait ComparisonOps {
rtol: Double = 1e-05,
atol: Double = 1e-08,
equalNan: Boolean = false
) =
): Boolean =
torchNative.allclose(input.native, other.native, rtol, atol, equalNan)

/** Returns the indices that sort a tensor along a given dimension in ascending order by value.
*
* This is the second value returned by `torch.sort`. See its documentation for the exact
* semantics of this method.
*
* If `stable` is `True` then the sorting routine becomes stable, preserving the order of
* equivalent elements. If `False`, the relative order of values which compare equal is not
* guaranteed. `True` is slower.
*
* Args: {input} dim (int, optional): the dimension to sort along descending (bool, optional):
* controls the sorting order (ascending or descending) stable (bool, optional): controls the
* relative order of equivalent elements
*
* Example:
*
* ```scala sc
* val a = torch.randn(Seq(4, 4))
* // tensor dtype=float32, shape=[4, 4], device=CPU
* // [[ 0.0785, 1.5267, -0.8521, 0.4065],
* // [ 0.1598, 0.0788, -0.0745, -1.2700],
* // [ 1.2208, 1.0722, -0.7064, 1.2564],
* // [ 0.0669, -0.2318, -0.8229, -0.9280]]
*
* torch.argsort(a, dim = 1)
* // tensor dtype=int64, shape=[4, 4], device=CPU
* // [[2, 0, 3, 1],
* // [3, 2, 1, 0],
* // [2, 1, 0, 3],
* // [3, 2, 1, 0]]
* ```
*
* @group comparison_ops
*/
def argsort[D <: RealNN](
input: Tensor[D],
dim: Int = -1,
descending: Boolean = false
// TODO implement stable, there are two boolean args in argsort and are not in order
// stable: Boolean = false
): Tensor[Int64] =
Tensor(
torchNative.argsort(input.native, dim.toLong, descending)
)
}
28 changes: 28 additions & 0 deletions core/src/test/scala/torch/ops/ComparisonOpsSuite.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
/*
* Copyright 2022 storch.dev
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package torch

class ComparisonOpsSuite extends TensorCheckSuite {

testUnaryOp(
op = argsort(_),
opName = "argsort",
inputTensor = Tensor(Seq(1, 3, 2)),
expectedTensor = Tensor(Seq(0L, 2L, 1L))
)

}

0 comments on commit c1ff1e3

Please sign in to comment.