
Commit

Add convenience extension methods for scalars and a few other enhancements
sbrunk authored and davoclavo committed Jul 27, 2023
1 parent 9a6409b commit d047712
Showing 4 changed files with 97 additions and 24 deletions.
71 changes: 67 additions & 4 deletions core/src/main/scala/torch/Tensor.scala
@@ -52,8 +52,10 @@ import spire.math.{Complex, UByte}
 
 import scala.reflect.Typeable
 import internal.NativeConverters
-import torch.Device.CPU
-import torch.Layout.Strided
+import internal.NativeConverters.toArray
+import internal.LoadCusolver
+import Device.CPU
+import Layout.Strided
 import org.bytedeco.pytorch.ByteArrayRef
 import org.bytedeco.pytorch.ShortArrayRef
 import org.bytedeco.pytorch.BoolArrayRef
@@ -64,7 +66,8 @@ import org.bytedeco.pytorch.DoubleArrayRef
 import org.bytedeco.pytorch.EllipsisIndexType
 import org.bytedeco.pytorch.SymInt
 import org.bytedeco.pytorch.SymIntOptional
-import internal.LoadCusolver
+import org.bytedeco.pytorch.ScalarTypeOptional
+import scala.annotation.implicitNotFound
 
 case class TensorTuple[D <: DType](
     values: Tensor[D],
@@ -306,6 +309,8 @@ sealed abstract class Tensor[D <: DType]( /* private[torch] */ val native: pyto
     native.flatten(startDim, endDim)
   )
 
+  def float: Tensor[Float32] = to(dtype = float32)
+
   /** This function returns an undefined tensor by default and returns a defined tensor the first
     * time a call to backward() computes gradients for this Tensor. The attribute will then contain
     * the gradients computed and future calls to backward() will accumulate (add) gradients into it.
@@ -406,6 +411,26 @@ sealed abstract class Tensor[D <: DType]( /* private[torch] */ val native: pyto
 
   def mean: Tensor[D] = Tensor(native.mean())
 
+  /** @see
+    *   [[torch.mean]]
+    */
+  def mean[D2 <: DType | Derive](
+      dim: Int | Seq[Int] = Seq.empty,
+      keepdim: Boolean = false,
+      dtype: D2 = derive
+  ): Tensor[DTypeOrDeriveFromTensor[D, D2]] =
+    val derivedDType = dtype match
+      case _: Derive => this.dtype
+      case d: DType  => d
+    Tensor(
+      torchNative.mean(
+        native,
+        dim.toArray,
+        keepdim,
+        new ScalarTypeOptional(derivedDType.toScalarType)
+      )
+    )
+
   def min(): Tensor[Int64] = Tensor[Int64](native.min())
 
   def minimum[D2 <: DType](other: Tensor[D2]): Tensor[Promoted[D, D2]] =
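
A usage sketch of the new `mean` overload (assuming the `Tensor(Seq(...))` constructor and `reshape`, as used elsewhere in this codebase):

  val t = Tensor(Seq(1f, 2f, 3f, 4f)).reshape(2, 2) // Tensor[Float32], shape [2, 2]
  t.mean                          // parameterless variant: grand mean over all elements
  t.mean(dim = 0)                 // per-column means, shape [2]
  t.mean(dim = 0, keepdim = true) // keeps the reduced dim, shape [1, 2]
  t.mean(dtype = float64)         // derive-or-override: result is Tensor[Float64]
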
@@ -425,7 +450,25 @@ sealed abstract class Tensor[D <: DType]( /* private[torch] */ val native: pyto
 
   def permute(dims: Int*): Tensor[D] = Tensor(native.permute(dims.map(_.toLong)*))
 
-  def pow(exponent: Double): Tensor[D] = Tensor(native.pow(Scalar.apply(exponent)))
+  /** @see [[torch.pow]] */
+  def pow[D2 <: DType](exponent: Tensor[D2])(using
+      @implicitNotFound(""""pow" not implemented for bool""")
+      ev1: Promoted[D, D2] NotEqual Bool,
+      @implicitNotFound(""""pow" not implemented for complex32""")
+      ev2: Promoted[D, D2] NotEqual Complex32
+  ): Tensor[Promoted[D, D2]] = Tensor(
+    native.pow(exponent.native)
+  )
+
+  /** @see [[torch.pow]] */
+  def pow[S <: ScalaType](exponent: S)(using
+      @implicitNotFound(""""pow" not implemented for bool""")
+      ev1: Promoted[D, ScalaToDType[S]] NotEqual Bool,
+      @implicitNotFound(""""pow" not implemented for complex32""")
+      ev2: Promoted[D, ScalaToDType[S]] NotEqual Complex32
+  ): Tensor[Promoted[D, ScalaToDType[S]]] = Tensor(
+    native.pow(exponent.toScalar)
+  )
 
   def prod[D <: DType](dtype: D = this.dtype) = Tensor(native.prod())
 
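To illustrate the new guards, a sketch of how these overloads resolve (the failing line uses Bool on both sides, since the evidence only rules out the case where the promoted dtype is Bool):

  val base = Tensor(Seq(1f, 2f, 3f))
  base.pow(2)                // scalar exponent, stays Tensor[Float32]
  base.pow(Tensor(Seq(2.0))) // tensor exponent, promotes to Tensor[Float64]
  // Tensor(Seq(true)).pow(Tensor(Seq(false))) // rejected at compile time:
  //                                           // "pow" not implemented for bool
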
@@ -516,6 +559,13 @@ sealed abstract class Tensor[D <: DType]( /* private[torch] */ val native: pyto
   def takeAlongDim(indices: Tensor[Int64], dim: Int) =
     native.take_along_dim(indices.native, toOptional(dim))
 
+  def to(device: Device | Option[Device]): Tensor[D] = device match
+    case dev: Device => to(dev, this.dtype)
+    case Some(dev)   => to(dev, this.dtype)
+    case None        => this
+
+  def to[U <: DType](dtype: U): Tensor[U] = to(this.device, dtype)
+
   // TODO support memory_format
   /** Performs Tensor dtype and/or device conversion. */
   def to[U <: DType](
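
Together with the new `float` shorthand above, conversions now read like this (a sketch; dtype-only and device-only overloads shown separately):

  val t = Tensor(Seq(1, 2, 3)) // Tensor[Int32]
  t.to(float64)                // dtype-only conversion: Tensor[Float64], same device
  t.float                      // shorthand for t.to(dtype = float32)
  t.to(Device.CPU)             // device-only conversion: dtype stays Int32
  val target: Option[Device] = None
  t.to(target)                 // None leaves the tensor unchanged
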
@@ -849,3 +899,16 @@ object Tensor:
         )
       )
     case _ => throw new IllegalArgumentException("Unsupported type")
+
+/** Scalar/Tensor extensions to allow tensor operations directly on scalars */
+extension [S <: ScalaType](s: S)
+  def +[D <: DType](t: Tensor[D]): Tensor[Promoted[D, ScalaToDType[S]]] = t.add(s)
+  def -[D <: DType](t: Tensor[D]): Tensor[Promoted[D, ScalaToDType[S]]] = t.sub(s)
+  def *[D <: DType](t: Tensor[D]): Tensor[Promoted[D, ScalaToDType[S]]] = t.mul(s)
+  def /[D <: DType](t: Tensor[D]): Tensor[Div[D, ScalaToDType[S]]] = t.div(s)
+  def **[D <: DType](t: Tensor[D])(using
+      @implicitNotFound(""""pow" not implemented for bool""")
+      ev1: Promoted[D, ScalaToDType[S]] NotEqual Bool,
+      @implicitNotFound(""""pow" not implemented for complex32""")
+      ev2: Promoted[D, ScalaToDType[S]] NotEqual Complex32
+  ): Tensor[Promoted[D, ScalaToDType[S]]] = t.pow(s)
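
These enable scalar-on-the-left arithmetic, which previously required the tensor as the receiver. A usage sketch:

  val t = Tensor(Seq(1f, 2f, 3f))
  2f * t // same as t.mul(2f): Tensor[Float32]
  2f + t // same as t.add(2f)
  1 - t  // Int promotes with Float32, result Tensor[Float32]
  // Note: as committed, s - t, s / t and s ** t delegate to t.sub(s), t.div(s)
  // and t.pow(s), so the scalar ends up on the right-hand side of the operation.
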
3 changes: 3 additions & 0 deletions core/src/main/scala/torch/Types.scala
@@ -54,3 +54,6 @@ type AtLeastOneFloat[A <: DType, B <: DType] = A <:< FloatNN | B <:< FloatNN
 /* Evidence used in operations where at least one Float or Complex is required */
 type AtLeastOneFloatOrComplex[A <: DType, B <: DType] = A <:< (FloatNN | ComplexNN) |
   B <:< (FloatNN | ComplexNN)
+
+/* Evidence that two dtypes are not the same */
+type NotEqual[D <: DType, D2 <: DType] = NotGiven[D =:= D2]
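
`NotGiven` makes this evidence available exactly when the compiler cannot prove the two dtypes equal. A minimal sketch of the pattern (the method is hypothetical, not part of this commit):

  import scala.util.NotGiven

  def noBools[D <: DType](t: Tensor[D])(using D NotEqual Bool): Tensor[D] = t
  // noBools(Tensor(Seq(true))) // does not compile: NotGiven[Bool =:= Bool] fails
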
16 changes: 8 additions & 8 deletions core/src/main/scala/torch/internal/NativeConverters.scala
@@ -83,14 +83,14 @@ private[torch] object NativeConverters:
 
   extension (x: ScalaType)
     def toScalar: pytorch.Scalar = x match
-      case x: Boolean => pytorch.Scalar(if x then 1: Byte else 0: Byte)
-      case x: UByte   => Tensor(x.toInt).to(dtype = uint8).native.item()
-      case x: Byte    => pytorch.Scalar(x)
-      case x: Short   => pytorch.Scalar(x)
-      case x: Int     => pytorch.Scalar(x)
-      case x: Long    => pytorch.Scalar(x)
-      case x: Float   => pytorch.Scalar(x)
-      case x: Double  => pytorch.Scalar(x)
+      case x: Boolean => pytorch.AbstractTensor.create(x).item()
+      case x: UByte   => pytorch.AbstractTensor.create(x.toInt).to(uint8.toScalarType).item()
+      case x: Byte    => pytorch.Scalar(x)
+      case x: Short   => pytorch.Scalar(x)
+      case x: Int     => pytorch.Scalar(x)
+      case x: Long    => pytorch.Scalar(x)
+      case x: Float   => pytorch.Scalar(x)
+      case x: Double  => pytorch.Scalar(x)
       case x @ Complex(r: Float, i: Float)   => Tensor(Seq(x)).to(dtype = complex64).native.item()
       case x @ Complex(r: Double, i: Double) => Tensor(Seq(x)).to(dtype = complex128).native.item()
 
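Presumably the point of routing Boolean and UByte through a native tensor is that the resulting Scalar keeps its bool/uint8 dtype instead of collapsing to a signed Byte. A sketch of the assumed effect:

  true.toScalar                   // a pytorch.Scalar carrying dtype bool, not a 0/1 Byte
  spire.math.UByte(200).toScalar  // stays in the unsigned range instead of overflowing Byte
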
31 changes: 19 additions & 12 deletions core/src/main/scala/torch/ops/PointwiseOps.scala
@@ -19,6 +19,7 @@ package ops
 
 import internal.NativeConverters.*
 import org.bytedeco.pytorch.global.torch as torchNative
+import scala.annotation.implicitNotFound
 
 /** Pointwise Ops
   *
@@ -673,22 +674,28 @@ private[torch] trait PointwiseOps {
    *
    * @group pointwise_ops
    */
-  def pow[D <: DType, D2 <: DType](
-      input: Tensor[D],
-      exponent: Tensor[D2]
-  )(using OnlyOneBool[D, D2]): Tensor[Promoted[D, D2]] =
+  def pow[D <: DType, D2 <: DType](input: Tensor[D], exponent: Tensor[D2])(using
+      @implicitNotFound(""""pow" not implemented for bool""")
+      ev1: Promoted[D, D2] NotEqual Bool,
+      @implicitNotFound(""""pow" not implemented for complex32""")
+      ev2: Promoted[D, D2] NotEqual Complex32
+  ): Tensor[Promoted[D, D2]] =
     Tensor(torchNative.pow(input.native, exponent.native))
 
-  def pow[D <: DType, S <: ScalaType](
-      input: Tensor[D],
-      exponent: S
-  )(using OnlyOneBool[D, ScalaToDType[S]]): Tensor[Promoted[D, ScalaToDType[S]]] =
+  def pow[D <: DType, S <: ScalaType](input: Tensor[D], exponent: S)(using
+      @implicitNotFound(""""pow" not implemented for bool""")
+      ev1: Promoted[D, ScalaToDType[S]] NotEqual Bool,
+      @implicitNotFound(""""pow" not implemented for complex32""")
+      ev2: Promoted[D, ScalaToDType[S]] NotEqual Complex32
+  ): Tensor[Promoted[D, ScalaToDType[S]]] =
     Tensor(torchNative.pow(input.native, toScalar(exponent)))
 
-  def pow[S <: ScalaType, D <: DType](
-      input: S,
-      exponent: Tensor[D]
-  )(using OnlyOneBool[ScalaToDType[S], D]): Tensor[Promoted[ScalaToDType[S], D]] =
+  def pow[S <: ScalaType, D <: DType](input: S, exponent: Tensor[D])(using
+      @implicitNotFound(""""pow" not implemented for bool""")
+      ev1: Promoted[D, ScalaToDType[S]] NotEqual Bool,
+      @implicitNotFound(""""pow" not implemented for complex32""")
+      ev2: Promoted[D, ScalaToDType[S]] NotEqual Complex32
+  ): Tensor[Promoted[ScalaToDType[S], D]] =
     Tensor(torchNative.pow(toScalar(input), exponent.native))
 
   // TODO Implement creation of QInts
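
The free-function forms mirror the new `Tensor.pow` methods. A usage sketch (assuming `import torch.*` so that `pow` and `Tensor` are in scope):

  pow(Tensor(Seq(2f, 3f)), 2)            // tensor base, scalar exponent
  pow(2f, Tensor(Seq(2f, 3f)))           // scalar base, tensor exponent
  pow(Tensor(Seq(2f)), Tensor(Seq(3.0))) // mixed dtypes promote, here to Float64
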
