Add Indexing, Slicing, Joining, Mutating Ops #35

Merged
merged 6 commits on Jul 2, 2023
101 changes: 67 additions & 34 deletions core/src/main/scala/torch/Tensor.scala
@@ -173,6 +173,8 @@ sealed abstract class Tensor[D <: DType]( /* private[torch] */ val native: pyto

def acos: Tensor[D] = Tensor(native.acos())

def adjoint: Tensor[D] = Tensor(native.adjoint())

/** Tests if all elements of this tensor evaluate to `true`. */
def all: Tensor[Bool] = Tensor(native.all())

@@ -285,6 +287,8 @@ sealed abstract class Tensor[D <: DType]( /* private[torch] */ val native: pyto
*/
def eq(other: Tensor[?]): Tensor[Bool] = Tensor(native.eq(other.native))

def ==(other: Tensor[?]): Tensor[Bool] = eq(other)

override def equals(that: Any): Boolean =
that match
case other: Tensor[?] if dtype == other.dtype => native.equal(other.native)
@@ -308,6 +312,14 @@ sealed abstract class Tensor[D <: DType]( /* private[torch] */ val native: pyto
*/
def grad: Tensor[D | Undefined] = Tensor(native.grad())

def ge(other: ScalaType): Tensor[Bool] = Tensor(native.ge(toScalar(other)))

def >=(other: ScalaType): Tensor[Bool] = ge(other)

def gt(other: ScalaType): Tensor[Bool] = Tensor(native.gt(toScalar(other)))

def >(other: ScalaType): Tensor[Bool] = gt(other)

def isContiguous: Boolean = native.is_contiguous()

def isCuda: Boolean = native.is_cuda()
@@ -318,6 +330,8 @@ sealed abstract class Tensor[D <: DType]( /* private[torch] */ val native: pyto

def isNonzero: Boolean = native.is_nonzero()

def isConj: Boolean = native.is_conj()

// TODO override in subclasses instead?
def item: DTypeToScala[D] =
import ScalarType.*
@@ -354,6 +368,14 @@ sealed abstract class Tensor[D <: DType]( /* private[torch] */ val native: pyto

def long: Tensor[Int64] = to(dtype = int64)

def le(other: ScalaType): Tensor[Bool] = Tensor(native.le(toScalar(other)))

def <=(other: ScalaType): Tensor[Bool] = le(other)

def lt(other: ScalaType): Tensor[Bool] = Tensor(native.lt(toScalar(other)))

def <(other: ScalaType): Tensor[Bool] = lt(other)

def matmul[D2 <: DType](u: Tensor[D2]): Tensor[Promoted[D, D2]] =
Tensor[Promoted[D, D2]](native.matmul(u.native))

@@ -757,47 +779,58 @@ object Tensor:

/** Constructs a tensor with no autograd history (also known as a “leaf tensor”) by copying data.
*/
-// TODO support multidimensional arrays as input
+// TODO support arbitrary multidimensional arrays as input
// TODO support explicit dtype
def apply[U <: ScalaType: ClassTag](
-data: Seq[U] | U,
+data: U | Seq[U] | Seq[Seq[U]] | Seq[Seq[Seq[U]]],
layout: Layout = Strided,
device: Device = CPU,
requiresGrad: Boolean = false
): Tensor[ScalaToDType[U]] =
data match
-case data: Seq[?] =>
-val (pointer, inputDType) = data.toArray match
-case bools: Array[Boolean] =>
-(
-{
-val p = new BoolPointer(bools.length)
-for ((b, i) <- bools.zipWithIndex) p.put(i, b)
-p
-},
-bool
-)
-case bytes: Array[Byte] => (new BytePointer(ByteBuffer.wrap(bytes)), int8)
-case shorts: Array[Short] => (new ShortPointer(ShortBuffer.wrap(shorts)), int16)
-case ints: Array[Int] => (new IntPointer(IntBuffer.wrap(ints)), int32)
-case longs: Array[Long] => (new LongPointer(LongBuffer.wrap(longs)), int64)
-case floats: Array[Float] => (new FloatPointer(FloatBuffer.wrap(floats)), float32)
-case doubles: Array[Double] => (new DoublePointer(DoubleBuffer.wrap(doubles)), float64)
-case complexFloatArray(complexFloats) =>
-(
-new FloatPointer(
-FloatBuffer.wrap(complexFloats.flatMap(c => Array(c.real, c.imag)))
-),
-complex64
-)
-case complexDoubleArray(complexDoubles) =>
-(
-new DoublePointer(
-DoubleBuffer.wrap(complexDoubles.flatMap(c => Array(c.real, c.imag)))
-),
-complex128
-)
-case _ => throw new IllegalArgumentException(s"Unsupported sequence type")
+case tripleSeq(data) =>
+apply(data.flatten.flatten.asInstanceOf[Seq[U]], layout, device, requiresGrad)
+.view(data.length, data.head.length, data.head.head.length)
+case doubleSeq(data) =>
+apply(data.flatten.asInstanceOf[Seq[U]], layout, device, requiresGrad)
+.view(data.length, data.head.length)
+case singleSeq(data) =>
+val (pointer, inputDType) =
+data.asInstanceOf[Seq[U]].toArray match
+case bools: Array[Boolean] =>
+(
+{
+val p = new BoolPointer(bools.length)
+for ((b, i) <- bools.zipWithIndex) p.put(i, b)
+p
+},
+bool
+)
+case bytes: Array[Byte] => (new BytePointer(ByteBuffer.wrap(bytes)), int8)
+case shorts: Array[Short] => (new ShortPointer(ShortBuffer.wrap(shorts)), int16)
+case ints: Array[Int] => (new IntPointer(IntBuffer.wrap(ints)), int32)
+case longs: Array[Long] => (new LongPointer(LongBuffer.wrap(longs)), int64)
+case floats: Array[Float] => (new FloatPointer(FloatBuffer.wrap(floats)), float32)
+case doubles: Array[Double] => (new DoublePointer(DoubleBuffer.wrap(doubles)), float64)
+case complexFloatArray(complexFloats) =>
+(
+new FloatPointer(
+FloatBuffer.wrap(complexFloats.flatMap(c => Array(c.real, c.imag)))
+),
+complex64
+)
+case complexDoubleArray(complexDoubles) =>
+(
+new DoublePointer(
+DoubleBuffer.wrap(complexDoubles.flatMap(c => Array(c.real, c.imag)))
+),
+complex128
+)
+case _ =>
+throw new IllegalArgumentException(
+s"Unsupported data type ${summon[ClassTag[U]].runtimeClass.getSimpleName}"
+)

Tensor(
torchNative
.from_blob(
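As a quick illustration of how the additions to `Tensor.scala` compose, here is a minimal, hypothetical usage sketch (it assumes the usual `torch` package exports and the `Bool` dtype used elsewhere in the library; the shapes and values are illustrative):

```scala
import torch.*

// Build a 2×3 Float32 tensor via the new nested-Seq overload of Tensor.apply.
val t = Tensor(Seq(Seq(1f, 2f, 3f), Seq(4f, 5f, 6f)))

// The new scalar comparison aliases return element-wise Bool tensors.
val gtMask: Tensor[Bool] = t > 2f  // alias for t.gt(2f)
val leMask: Tensor[Bool] = t <= 4f // alias for t.le(4f)

// The new == alias delegates to eq for element-wise tensor equality,
// while equals on Any still gives Boolean structural equality.
val same: Tensor[Bool] = t == t
```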
5 changes: 5 additions & 0 deletions core/src/main/scala/torch/Types.scala
@@ -36,6 +36,11 @@ given iterableTypeable[T](using tt: Typeable[T]): Typeable[Array[T]] with
val complexDoubleArray = TypeCase[Array[Complex[Double]]]
val complexFloatArray = TypeCase[Array[Complex[Float]]]

/* TypeCase helpers to perform pattern matching on nested `Seq` types */
val singleSeq = TypeCase[Seq[?]]
val doubleSeq = TypeCase[Seq[Seq[?]]]
val tripleSeq = TypeCase[Seq[Seq[Seq[?]]]]

/* Type helper to describe inputs that accept Tensor or Real scalars */
type TensorOrReal[D <: RealNN] = Tensor[D] | Real

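Note that ordering matters when matching with these helpers: under `Typeable`, a `Seq[Seq[Seq[?]]]` also matches `doubleSeq` and `singleSeq`, which is why `Tensor.apply` above tests the deepest shape first. A hypothetical sketch of the idiom (the `rank` helper is illustrative, not part of this PR):

```scala
// Illustrative only: classify the nesting depth of an erased value.
// The deepest TypeCase must come first, or shallower cases shadow it.
def rank(data: Any): Int = data match
  case tripleSeq(_) => 3 // matches Seq[Seq[Seq[?]]]
  case doubleSeq(_) => 2 // matches Seq[Seq[?]]
  case singleSeq(_) => 1 // matches Seq[?]
  case _            => 0 // a scalar or anything else

// rank(Seq(Seq(1, 2), Seq(3, 4))) == 2
```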
46 changes: 45 additions & 1 deletion core/src/main/scala/torch/ops/ComparisonOps.scala
@@ -31,6 +31,50 @@ private[torch] trait ComparisonOps {
rtol: Double = 1e-05,
atol: Double = 1e-08,
equalNan: Boolean = false
-) =
+): Boolean =
torchNative.allclose(input.native, other.native, rtol, atol, equalNan)

/** Returns the indices that sort a tensor along a given dimension in ascending order by value.
*
* This is the second value returned by `torch.sort`. See its documentation for the exact
* semantics of this method.
*
* If `stable` is `true`, the sorting routine becomes stable, preserving the order of
* equivalent elements. If `false`, the relative order of values which compare equal is not
* guaranteed. `true` is slower.
*
* @param input
*   the input tensor
* @param dim
*   the dimension to sort along
* @param descending
*   controls the sorting order (ascending or descending)
*
* Example:
*
* ```scala sc
* val a = torch.randn(Seq(4, 4))
* // tensor dtype=float32, shape=[4, 4], device=CPU
* // [[ 0.0785, 1.5267, -0.8521, 0.4065],
* // [ 0.1598, 0.0788, -0.0745, -1.2700],
* // [ 1.2208, 1.0722, -0.7064, 1.2564],
* // [ 0.0669, -0.2318, -0.8229, -0.9280]]
*
* torch.argsort(a, dim = 1)
* // tensor dtype=int64, shape=[4, 4], device=CPU
* // [[2, 0, 3, 1],
* // [3, 2, 1, 0],
* // [2, 1, 0, 3],
* // [3, 2, 1, 0]]
* ```
*
* @group comparison_ops
*/
def argsort[D <: RealNN](
input: Tensor[D],
dim: Int = -1,
descending: Boolean = false
// TODO implement stable: the native argsort takes two boolean args and they are not in order
// stable: Boolean = false
): Tensor[Int64] =
Tensor(
torchNative.argsort(input.native, dim.toLong, descending)
)
}
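A short usage sketch of the two changes in this file, assuming (as elsewhere in the library) that these trait methods are re-exported at the `torch` package level; the values are illustrative:

```scala
val v = torch.Tensor(Seq(0.3f, -1.2f, 0.7f))

// Indices that would sort v in ascending order: here [1, 0, 2].
val idx = torch.argsort(v)

// allclose now advertises its Boolean result type explicitly.
val close: Boolean = torch.allclose(v, v)
```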