Implement reduction ops
sbrunk committed Jun 18, 2023
1 parent 4d8a82f commit 6ae3bb8
Showing 6 changed files with 1,139 additions and 117 deletions.
@@ -14,4 +14,7 @@
* limitations under the License.
*/

package object torch {}
/** @groupname pointwise_ops Pointwise Ops
* @groupname reduction_ops Reduction Ops
*/
package object torch
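
The two `@groupname` tags declare Scaladoc groups on the package object; individual ops can then be sorted into those groups with an `@group` tag. A minimal sketch of the mechanism, using a hypothetical op rather than an actual signature from this commit:

package torch

/** Returns the sum of all elements in the input tensor.
  *
  * @group reduction_ops
  */
def sum[D <: DType](input: Tensor[D]): Tensor[D] = ???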
3 changes: 3 additions & 0 deletions core/src/main/scala/torch/Types.scala
@@ -45,3 +45,6 @@ type OnlyOneBool[A <: DType, B <: DType] = NotGiven[A =:= Bool & B =:= Bool]

/* Evidence used in operations where at least one Float is required */
type AtLeastOneFloat[A <: DType, B <: DType] = A <:< FloatNN | B <:< FloatNN

/* Evidence used in operations where at least one Float or Complex is required */
type AtLeastOneFloatOrComplex[A <: DType, B <: DType] = A <:< (FloatNN | ComplexNN) | B <:< (FloatNN | ComplexNN)
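
This evidence lets ops that PyTorch only defines for floating-point or complex dtypes, such as `mean`, reject other dtypes at compile time. A hedged sketch of how a signature might use it (the actual ops in this commit may differ):

// Hypothetical signature, not necessarily the one in this commit:
// mean is only defined for floating-point and complex dtypes in PyTorch,
// so we require evidence that the input or target dtype satisfies the bound.
def mean[D <: DType, D2 <: DType](input: Tensor[D], dtype: D2)(using
    AtLeastOneFloatOrComplex[D, D2]
): Tensor[D2] = ???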
92 changes: 56 additions & 36 deletions core/src/main/scala/torch/internal/NativeConverters.scala
@@ -34,48 +34,68 @@ import org.bytedeco.pytorch.GenericDict
import org.bytedeco.pytorch.GenericDictIterator
import spire.math.Complex
import spire.math.UByte
import scala.annotation.targetName

private[torch] object NativeConverters:

inline def toOptional[T, U <: T | Option[T], V >: Null](i: U, f: T => V): V = i match
inline def convertToOptional[T, U <: T | Option[T], V >: Null](i: U, f: T => V): V = i match
case i: Option[T] => i.map(f(_)).orNull
case i: T => f(i)

def toOptional(l: Long | Option[Long]): pytorch.LongOptional =
toOptional(l, pytorch.LongOptional(_))
def toOptional(l: Double | Option[Double]): pytorch.DoubleOptional =
toOptional(l, pytorch.DoubleOptional(_))

def toOptional(l: Real | Option[Real]): pytorch.ScalarOptional =
toOptional(
l,
(r: Real) =>
val scalar = toScalar(r)
pytorch.ScalarOptional(scalar)
)

def toOptional[D <: DType](t: Tensor[D] | Option[Tensor[D]]): TensorOptional =
toOptional(t, t => pytorch.TensorOptional(t.native))

def toArray(i: Long | (Long, Long)) = i match
case i: Long => Array(i)
case (i, j) => Array(i, j)

def toNative(input: Int | (Int, Int)) = input match
case (h, w) => LongPointer(Array(h.toLong, w.toLong)*)
case x: Int => LongPointer(Array(x.toLong, x.toLong)*)

def toScalar(x: ScalaType): pytorch.Scalar = x match
case x: Boolean => pytorch.Scalar(if x then 1: Byte else 0: Byte)
case x: UByte => Tensor(x.toInt).to(dtype = uint8).native.item()
case x: Byte => pytorch.Scalar(x)
case x: Short => pytorch.Scalar(x)
case x: Int => pytorch.Scalar(x)
case x: Long => pytorch.Scalar(x)
case x: Float => pytorch.Scalar(x)
case x: Double => pytorch.Scalar(x)
case x @ Complex(r: Float, i: Float) => Tensor(Seq(x)).to(dtype = complex64).native.item()
case x @ Complex(r: Double, i: Double) => Tensor(Seq(x)).to(dtype = complex128).native.item()
extension (l: Long | Option[Long])
def toOptional: pytorch.LongOptional = convertToOptional(l, pytorch.LongOptional(_))

extension (l: Double | Option[Double])
def toOptional: pytorch.DoubleOptional = convertToOptional(l, pytorch.DoubleOptional(_))

extension (l: Real | Option[Real])
def toOptional: pytorch.ScalarOptional =
convertToOptional(
l,
(r: Real) =>
val scalar = toScalar(r)
pytorch.ScalarOptional(scalar)
)

extension [D <: DType](t: Tensor[D] | Option[Tensor[D]])
def toOptional: TensorOptional =
convertToOptional(t, t => pytorch.TensorOptional(t.native))

extension (i: Long | (Long, Long))
def toArray = i match
case i: Long => Array(i)
case (i, j) => Array(i, j)

extension (i: Int | Seq[Int])
@targetName("intOrIntSeqToArray")
def toArray: Array[Long] = i match
case i: Int => Array(i.toLong)
case i: Seq[Int] => i.map(_.toLong).toArray

extension (i: Long | Seq[Long])
@targetName("longOrLongSeqToArray")
def toArray: Array[Long] = i match
case i: Long => Array(i)
case i: Seq[Long] => i.toArray

extension (input: Int | (Int, Int))
def toNative = input match
case (h, w) => LongPointer(Array(h.toLong, w.toLong)*)
case x: Int => LongPointer(Array(x.toLong, x.toLong)*)

extension (x: ScalaType)
def toScalar: pytorch.Scalar = x match
case x: Boolean => pytorch.Scalar(if x then 1: Byte else 0: Byte)
case x: UByte => Tensor(x.toInt).to(dtype = uint8).native.item()
case x: Byte => pytorch.Scalar(x)
case x: Short => pytorch.Scalar(x)
case x: Int => pytorch.Scalar(x)
case x: Long => pytorch.Scalar(x)
case x: Float => pytorch.Scalar(x)
case x: Double => pytorch.Scalar(x)
case x @ Complex(r: Float, i: Float) => Tensor(Seq(x)).to(dtype = complex64).native.item()
case x @ Complex(r: Double, i: Double) => Tensor(Seq(x)).to(dtype = complex128).native.item()

def tensorOptions(
dtype: DType,
layout: Layout,
(diff truncated; the remaining three changed files are not shown)
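
Taken together, the refactoring turns the converter functions into extension methods on the union types they accept; the `@targetName` annotations keep the `toArray` overloads distinct after JVM erasure, where their union-typed parameters would otherwise collide. A usage sketch from code inside the `torch` package (the object is package-private, and the values shown are illustrative):

import torch.internal.NativeConverters.*

// Each union-typed value resolves to the matching extension at compile time.
val dim: Long | Option[Long] = Some(1L)
val nativeDim = dim.toOptional   // pytorch.LongOptional

val kernel: Int | (Int, Int) = (3, 3)
val kernelPtr = kernel.toNative  // LongPointer over Array(3L, 3L)

val alpha = 0.5.toScalar         // pytorch.Scalar wrapping a Double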
