From d047712cff150410a20ae4da6b494ded82d110bb Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?S=C3=B6ren=20Brunk?=
Date: Mon, 17 Jul 2023 22:28:43 +0200
Subject: [PATCH] Add convenience extension methods for scalars and a few
 other enhancements

---
 core/src/main/scala/torch/Tensor.scala        | 71 +++++++++++++++++--
 core/src/main/scala/torch/Types.scala         |  3 +
 .../torch/internal/NativeConverters.scala     | 16 ++---
 .../main/scala/torch/ops/PointwiseOps.scala   | 31 ++++----
 4 files changed, 97 insertions(+), 24 deletions(-)

diff --git a/core/src/main/scala/torch/Tensor.scala b/core/src/main/scala/torch/Tensor.scala
index d7ea24a6..7ad0c3bb 100644
--- a/core/src/main/scala/torch/Tensor.scala
+++ b/core/src/main/scala/torch/Tensor.scala
@@ -52,8 +52,10 @@ import spire.math.{Complex, UByte}
 import scala.reflect.Typeable
 
 import internal.NativeConverters
-import torch.Device.CPU
-import torch.Layout.Strided
+import internal.NativeConverters.toArray
+import internal.LoadCusolver
+import Device.CPU
+import Layout.Strided
 import org.bytedeco.pytorch.ByteArrayRef
 import org.bytedeco.pytorch.ShortArrayRef
 import org.bytedeco.pytorch.BoolArrayRef
@@ -64,7 +66,8 @@ import org.bytedeco.pytorch.DoubleArrayRef
 import org.bytedeco.pytorch.EllipsisIndexType
 import org.bytedeco.pytorch.SymInt
 import org.bytedeco.pytorch.SymIntOptional
-import internal.LoadCusolver
+import org.bytedeco.pytorch.ScalarTypeOptional
+import scala.annotation.implicitNotFound
 
 case class TensorTuple[D <: DType](
     values: Tensor[D],
@@ -306,6 +309,8 @@ sealed abstract class Tensor[D <: DType]( /* private[torch] */ val native: pyto
     native.flatten(startDim, endDim)
   )
 
+  def float: Tensor[Float32] = to(dtype = float32)
+
   /** This function returns an undefined tensor by default and returns a defined tensor the first
     * time a call to backward() computes gradients for this Tensor. The attribute will then contain
     * the gradients computed and future calls to backward() will accumulate (add) gradients into it.
@@ -406,6 +411,26 @@ sealed abstract class Tensor[D <: DType]( /* private[torch] */ val native: pyto
 
   def mean: Tensor[D] = Tensor(native.mean())
 
+  /** @see
+    *   [[torch.mean]]
+    */
+  def mean[D2 <: DType | Derive](
+      dim: Int | Seq[Int] = Seq.empty,
+      keepdim: Boolean = false,
+      dtype: D2 = derive
+  ): Tensor[DTypeOrDeriveFromTensor[D, D2]] =
+    val derivedDType = dtype match
+      case _: Derive => this.dtype
+      case d: DType  => d
+    Tensor(
+      torchNative.mean(
+        native,
+        dim.toArray,
+        keepdim,
+        new ScalarTypeOptional(derivedDType.toScalarType)
+      )
+    )
+
   def min(): Tensor[Int64] = Tensor[Int64](native.min())
 
   def minimum[D2 <: DType](other: Tensor[D2]): Tensor[Promoted[D, D2]] =
@@ -425,7 +450,25 @@ sealed abstract class Tensor[D <: DType]( /* private[torch] */ val native: pyto
 
   def permute(dims: Int*): Tensor[D] = Tensor(native.permute(dims.map(_.toLong)*))
 
-  def pow(exponent: Double): Tensor[D] = Tensor(native.pow(Scalar.apply(exponent)))
+  /** @see [[torch.pow]] */
+  def pow[D2 <: DType](exponent: Tensor[D2])(using
+      @implicitNotFound(""""pow" not implemented for bool""")
+      ev1: Promoted[D, D2] NotEqual Bool,
+      @implicitNotFound(""""pow" not implemented for complex32""")
+      ev2: Promoted[D, D2] NotEqual Complex32
+  ): Tensor[Promoted[D, D2]] = Tensor(
+    native.pow(exponent.native)
+  )
+
+  /** @see [[torch.pow]] */
+  def pow[S <: ScalaType](exponent: S)(using
+      @implicitNotFound(""""pow" not implemented for bool""")
+      ev1: Promoted[D, ScalaToDType[S]] NotEqual Bool,
+      @implicitNotFound(""""pow" not implemented for complex32""")
+      ev2: Promoted[D, ScalaToDType[S]] NotEqual Complex32
+  ): Tensor[Promoted[D, ScalaToDType[S]]] = Tensor(
+    native.pow(exponent.toScalar)
+  )
 
   def prod[D <: DType](dtype: D = this.dtype) = Tensor(native.prod())
 
@@ -516,6 +559,13 @@ sealed abstract class Tensor[D <: DType]( /* private[torch] */ val native: pyto
   def takeAlongDim(indices: Tensor[Int64], dim: Int) =
     native.take_along_dim(indices.native, toOptional(dim))
 
+  def to(device: Device | Option[Device]): Tensor[D] = device match
+    case dev: Device => to(dev, this.dtype)
+    case Some(dev)   => to(dev, this.dtype)
+    case None        => this
+
+  def to[U <: DType](dtype: U): Tensor[U] = to(this.device, dtype)
+
   // TODO support memory_format
   /** Performs Tensor dtype and/or device conversion. */
   def to[U <: DType](
@@ -849,3 +899,16 @@ object Tensor:
           )
         )
       case _ => throw new IllegalArgumentException("Unsupported type")
+
+/** Scalar/Tensor extensions to allow tensor operations directly on scalars */
+extension [S <: ScalaType](s: S)
+  def +[D <: DType](t: Tensor[D]): Tensor[Promoted[D, ScalaToDType[S]]] = t.add(s)
+  def -[D <: DType](t: Tensor[D]): Tensor[Promoted[D, ScalaToDType[S]]] = t.sub(s)
+  def *[D <: DType](t: Tensor[D]): Tensor[Promoted[D, ScalaToDType[S]]] = t.mul(s)
+  def /[D <: DType](t: Tensor[D]): Tensor[Div[D, ScalaToDType[S]]] = t.div(s)
+  def **[D <: DType](t: Tensor[D])(using
+      @implicitNotFound(""""pow" not implemented for bool""")
+      ev1: Promoted[D, ScalaToDType[S]] NotEqual Bool,
+      @implicitNotFound(""""pow" not implemented for complex32""")
+      ev2: Promoted[D, ScalaToDType[S]] NotEqual Complex32
+  ): Tensor[Promoted[D, ScalaToDType[S]]] = t.pow(s)
diff --git a/core/src/main/scala/torch/Types.scala b/core/src/main/scala/torch/Types.scala
index f15d923f..3449201c 100644
--- a/core/src/main/scala/torch/Types.scala
+++ b/core/src/main/scala/torch/Types.scala
@@ -54,3 +54,6 @@ type AtLeastOneFloat[A <: DType, B <: DType] = A <:< FloatNN | B <:< FloatNN
 /* Evidence used in operations where at least one Float or Complex is required */
 type AtLeastOneFloatOrComplex[A <: DType, B <: DType] =
   A <:< (FloatNN | ComplexNN) | B <:< (FloatNN | ComplexNN)
+
+/* Evidence that two dtypes are not the same */
+type NotEqual[D <: DType, D2 <: DType] = NotGiven[D =:= D2]
diff --git a/core/src/main/scala/torch/internal/NativeConverters.scala b/core/src/main/scala/torch/internal/NativeConverters.scala
index 996ee76d..e18adee3 100644
--- a/core/src/main/scala/torch/internal/NativeConverters.scala
+++ b/core/src/main/scala/torch/internal/NativeConverters.scala
@@ -83,14 +83,14 @@ private[torch] object NativeConverters:
 
   extension (x: ScalaType)
     def toScalar: pytorch.Scalar = x match
-      case x: Boolean => pytorch.Scalar(if x then 1: Byte else 0: Byte)
-      case x: UByte   => Tensor(x.toInt).to(dtype = uint8).native.item()
-      case x: Byte    => pytorch.Scalar(x)
-      case x: Short   => pytorch.Scalar(x)
-      case x: Int     => pytorch.Scalar(x)
-      case x: Long    => pytorch.Scalar(x)
-      case x: Float   => pytorch.Scalar(x)
-      case x: Double  => pytorch.Scalar(x)
+      case x: Boolean => pytorch.AbstractTensor.create(x).item()
+      case x: UByte   => pytorch.AbstractTensor.create(x.toInt).to(uint8.toScalarType).item()
+      case x: Byte    => pytorch.Scalar(x)
+      case x: Short   => pytorch.Scalar(x)
+      case x: Int     => pytorch.Scalar(x)
+      case x: Long    => pytorch.Scalar(x)
+      case x: Float   => pytorch.Scalar(x)
+      case x: Double  => pytorch.Scalar(x)
       case x @ Complex(r: Float, i: Float)   => Tensor(Seq(x)).to(dtype = complex64).native.item()
       case x @ Complex(r: Double, i: Double) => Tensor(Seq(x)).to(dtype = complex128).native.item()
 
diff --git a/core/src/main/scala/torch/ops/PointwiseOps.scala b/core/src/main/scala/torch/ops/PointwiseOps.scala
index ec3e2c0e..643472f1 100644
--- a/core/src/main/scala/torch/ops/PointwiseOps.scala
+++ b/core/src/main/scala/torch/ops/PointwiseOps.scala
@@ -19,6 +19,7 @@ package ops
 
 import internal.NativeConverters.*
 import org.bytedeco.pytorch.global.torch as torchNative
+import scala.annotation.implicitNotFound
 
 /** Pointwise Ops
   *
@@ -673,22 +674,28 @@
     *
     * @group pointwise_ops
     */
-  def pow[D <: DType, D2 <: DType](
-      input: Tensor[D],
-      exponent: Tensor[D2]
-  )(using OnlyOneBool[D, D2]): Tensor[Promoted[D, D2]] =
+  def pow[D <: DType, D2 <: DType](input: Tensor[D], exponent: Tensor[D2])(using
+      @implicitNotFound(""""pow" not implemented for bool""")
+      ev1: Promoted[D, D2] NotEqual Bool,
+      @implicitNotFound(""""pow" not implemented for complex32""")
+      ev2: Promoted[D, D2] NotEqual Complex32
+  ): Tensor[Promoted[D, D2]] =
     Tensor(torchNative.pow(input.native, exponent.native))
 
-  def pow[D <: DType, S <: ScalaType](
-      input: Tensor[D],
-      exponent: S
-  )(using OnlyOneBool[D, ScalaToDType[S]]): Tensor[Promoted[D, ScalaToDType[S]]] =
+  def pow[D <: DType, S <: ScalaType](input: Tensor[D], exponent: S)(using
+      @implicitNotFound(""""pow" not implemented for bool""")
+      ev1: Promoted[D, ScalaToDType[S]] NotEqual Bool,
+      @implicitNotFound(""""pow" not implemented for complex32""")
+      ev2: Promoted[D, ScalaToDType[S]] NotEqual Complex32
+  ): Tensor[Promoted[D, ScalaToDType[S]]] =
     Tensor(torchNative.pow(input.native, toScalar(exponent)))
 
-  def pow[S <: ScalaType, D <: DType](
-      input: S,
-      exponent: Tensor[D]
-  )(using OnlyOneBool[ScalaToDType[S], D]): Tensor[Promoted[ScalaToDType[S], D]] =
+  def pow[S <: ScalaType, D <: DType](input: S, exponent: Tensor[D])(using
+      @implicitNotFound(""""pow" not implemented for bool""")
+      ev1: Promoted[D, ScalaToDType[S]] NotEqual Bool,
+      @implicitNotFound(""""pow" not implemented for complex32""")
+      ev2: Promoted[D, ScalaToDType[S]] NotEqual Complex32
+  ): Tensor[Promoted[ScalaToDType[S], D]] =
     Tensor(torchNative.pow(toScalar(input), exponent.native))
 
   // TODO Implement creation of QInts
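
For reference, a minimal usage sketch of what this patch enables. It is not part of the commit: it assumes a storch build that includes these changes, and it leans on the pre-existing `Tensor(Seq(...))` constructor and the `float32`/`float64` dtype values from the `torch` package.

```scala
import torch.*

// A small Float32 tensor to operate on.
val t = Tensor(Seq(0f, 1f, 2f, 3f))

// Scalars may now appear on the left-hand side of tensor arithmetic via the
// new extensions; result dtypes follow the usual promotion rules.
val shifted = 1f + t  // Tensor[Float32]
val scaled  = 2.0 * t // promotes to Tensor[Float64]

// The NotEqual evidence rejects combinations whose promoted dtype would be
// bool or complex32 at compile time, using the @implicitNotFound messages.

// mean over selected dimensions, deriving the output dtype from the tensor:
val mean0 = t.mean(dim = 0, keepdim = false)

// dtype-only (and device-only) conversion without passing both parameters:
val asDouble = t.to(dtype = float64)
val asFloat  = asDouble.float // shorthand for to(dtype = float32)
```

Note that, as written, the `-`, `/`, and `**` extensions delegate to `t.sub(s)`, `t.div(s)`, and `t.pow(s)`, so the scalar still acts as the right-hand operand of the underlying native op even when it appears on the left; only the commutative operators are shown above.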