Merge pull request #23 from davoclavo/tensor-ops
Tensor Pointwise ops
sbrunk authored Jun 4, 2023
2 parents 875a245 + a81b76f commit 66266aa
Showing 17 changed files with 2,171 additions and 50 deletions.
4 changes: 4 additions & 0 deletions .envrc
@@ -0,0 +1,4 @@
watch_file devenv.nix
watch_file devenv.yaml
watch_file devenv.lock
eval "$(devenv print-dev-env)"
5 changes: 5 additions & 0 deletions .gitignore
@@ -9,3 +9,8 @@ metals.sbt
*.worksheet.sc
/data/
.scala-build/

# Devenv
.devenv*
devenv.local.nix

53 changes: 53 additions & 0 deletions CONTRIBUTING.md
@@ -0,0 +1,53 @@
# Contributing to Storch


## Install dependencies via nix + devenv

1. Install [nix](https://nixos.org) package manager

```bash
sh <(curl -L https://nixos.org/nix/install) --daemon
```

For more info, see https://nixos.org/download.html

2. Install [devenv](https://devenv.sh)

```bash
nix profile install --accept-flake-config github:cachix/devenv/latest
```

For more info, see: https://devenv.sh/getting-started/#installation

3. (Optional) Install [direnv](https://direnv.net/)

This automatically loads the project-specific environment variables when you `cd` into the storch folder

```bash
nix profile install 'nixpkgs#direnv'
```

4. Load environment

If you did not install direnv, run the following in the `storch` root folder:

```bash
devenv shell
```

If you installed direnv, just `cd` into the storch folder
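
If direnv reports the new `.envrc` as blocked on first entry, approve it once (standard direnv usage, not specific to this repo):

```bash
direnv allow
```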


## Linting

### Manually run headerCreate + scalafmt on all files

```bash
sbt 'headerCreateAll ; scalafmtAll'
```

### Add useful git pre-push linting checks

```bash
cp git-hooks/pre-push-checks .git/hooks/ && chmod +x git-hooks/pre-push-checks
```
1 change: 1 addition & 0 deletions build.sbt
@@ -80,6 +80,7 @@ lazy val core = project
libraryDependencies ++= Seq(
"org.bytedeco" % "pytorch" % s"$pytorchVersion-${javaCppVersion.value}",
"org.typelevel" %% "spire" % "0.18.0",
"org.typelevel" %% "shapeless3-typeable" % "3.2.0",
"com.lihaoyi" %% "os-lib" % "0.9.0",
"com.lihaoyi" %% "sourcecode" % "0.3.0",
"dev.dirs" % "directories" % "26",
56 changes: 49 additions & 7 deletions core/src/main/scala/torch/DType.scala
@@ -32,7 +32,7 @@ import java.nio.{
}
import scala.annotation.{targetName, unused}
import scala.reflect.ClassTag
import spire.math.*
import spire.math.{Complex, UByte}

import scala.compiletime.{erasedValue, summonFrom}

@@ -209,14 +209,31 @@ private object Derive:
val derive: Derive = Derive()
export Derive.derive

/** DType combinations */
type FloatNN = Float16 | Float32 | Float64 | BFloat16

type IntNN = Int8 | UInt8 | Int16 | Int32 | Int64

type ComplexNN = Complex32 | Complex64 | Complex128

type ScalaType = Boolean | Byte | UByte | Short | Int | Long | Float | Double | Complex[Float] |
Complex[Double]
type BitwiseNN = Bool | IntNN

type NumericRealNN = IntNN | FloatNN

type RealNN = NumericRealNN | Bool

type NumericNN = NumericRealNN | ComplexNN

/** Scala type combinations */
type NumericReal = Byte | UByte | Short | Int | Long | Float | Double

type Real = NumericReal | Boolean

type ComplexScala = Complex[Float] | Complex[Double]

type Numeric = NumericReal | ComplexScala

type ScalaType = Real | ComplexScala

type DTypeToScala[T <: DType] <: ScalaType = T match
case UInt8 => UByte
@@ -366,15 +383,40 @@ type Promoted[T <: DType, U <: DType] <: DType = (T, U) match
case (T, Complex128) => T
case _ => DType

/** Promoted type for tensor operations that always output numbers (e.g. `square`) */
type NumericPromoted[D <: DType] <: DType = D match
case Bool => Int64
case _ => D

/** Promoted type for tensor operations that always output floats (e.g. `sin`) */
type FloatPromoted[D <: DType] <: FloatNN = D match
case Float64 => Float64
case _ => Float32

/** Demoted type for complex to real type extractions (e.g. `imag`, `real`) */
type ComplexToReal[D <: DType] <: DType = D match
case Complex32 => Float16
case Complex64 => Float32
case Complex128 => Float64
case _ => D

/** Promoted type for tensor operations that always output full sized complex or real (e.g.
* `floatPower`)
*/
type ComplexPromoted[T <: DType, U <: DType] <: Float64 | Complex128 = (T, U) match
case (ComplexNN, U) => Complex128
case (T, ComplexNN) => Complex128
case _ => Float64

/** Promoted type for tensor division */
type Div[T <: DType, U <: DType] <: DType = (T, U) match
case (Bool | IntNN, Bool | IntNN) => Float32
case _ => Promoted[T, U]
case (BitwiseNN, BitwiseNN) => Float32
case _ => Promoted[T, U]

/** Promoted type for elementwise tensor sum */
type Sum[D <: DType] <: DType = D match
case Bool | IntNN => Int64
case D => D
case BitwiseNN => Int64
case D => D

private[torch] type TypedBuffer[T <: ScalaType] <: Buffer = T match
case Short => ShortBuffer
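
A minimal sketch of what the promotion rules above resolve to at compile time (type-level checks only; each `summon` compiles exactly when the match type reduces as claimed):

```scala
import torch.*

// Illustrative compile-time checks against the match types defined above.
summon[FloatPromoted[Int32] =:= Float32]     // float-producing ops (e.g. `sin`) return Float32
summon[FloatPromoted[Float64] =:= Float64]   // ...except for Float64 inputs, which are preserved
summon[ComplexToReal[Complex64] =:= Float32] // `real`/`imag` of a Complex64 tensor yield Float32
summon[Div[Int32, Int64] =:= Float32]        // integer / integer promotes to Float32
summon[Sum[Bool] =:= Int64]                  // summing booleans accumulates in Int64
```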
40 changes: 29 additions & 11 deletions core/src/main/scala/torch/Tensor.scala
@@ -321,17 +321,21 @@ sealed abstract class Tensor[D <: DType]( /* private[torch] */ val native: pyto
def item: DTypeToScala[D] =
import ScalarType.*
val out = native.dtype().toScalarType().intern() match
case Byte => UByte(native.item_int())
case Char => native.item_byte()
case Short => native.item_short()
case Int => native.item_int()
case Long => native.item_long()
case Half => native.item().toHalf.asFloat()
case Float => native.item_float()
case Double => native.item_double()
case ComplexHalf => ??? // TODO how to access complex scalar values?
case ComplexFloat => ???
case ComplexDouble => ???
case Byte => UByte(native.item_int())
case Char => native.item_byte()
case Short => native.item_short()
case Int => native.item_int()
case Long => native.item_long()
case Half => native.item().toHalf.asFloat()
case Float => native.item_float()
case Double => native.item_double()
case ComplexHalf => ??? // TODO how to access complex scalar values?
case ComplexFloat =>
val b = native.contiguous.createBuffer[FloatBuffer]
Complex(b.get(), b.get())
case ComplexDouble =>
val b = native.contiguous.createBuffer[DoubleBuffer]
Complex(b.get(), b.get())
case Bool => native.item().toBool
case QInt8 => native.item_byte()
case QUInt8 => native.item_short()
@@ -771,6 +775,20 @@ object Tensor:
case longs: Array[Long] => (new LongPointer(LongBuffer.wrap(longs)), int64)
case floats: Array[Float] => (new FloatPointer(FloatBuffer.wrap(floats)), float32)
case doubles: Array[Double] => (new DoublePointer(DoubleBuffer.wrap(doubles)), float64)
case complexFloatArray(complexFloats) =>
(
new FloatPointer(
FloatBuffer.wrap(complexFloats.flatMap(c => Array(c.real, c.imag)))
),
complex64
)
case complexDoubleArray(complexDoubles) =>
(
new DoublePointer(
DoubleBuffer.wrap(complexDoubles.flatMap(c => Array(c.real, c.imag)))
),
complex128
)
case _ => throw new IllegalArgumentException(s"Unsupported sequence type")
Tensor(
torchNative
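
Together with the `item` change above, this round-trips complex scalars: the factory flattens `Complex` values into an interleaved real/imaginary buffer, and `item` reads them back. A hedged sketch (assuming the `Seq`-based `Tensor` factory used elsewhere in this diff):

```scala
import spire.math.Complex

// Build a single-element complex64 tensor and read the scalar back out.
val t = Tensor(Seq(Complex(1.0f, 2.0f)))
val z = t.item // Complex(1.0f, 2.0f), read from the tensor's FloatBuffer
```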
47 changes: 47 additions & 0 deletions core/src/main/scala/torch/Types.scala
@@ -0,0 +1,47 @@
/*
* Copyright 2022 storch.dev
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package torch

import shapeless3.typeable.{TypeCase, Typeable}
import shapeless3.typeable.syntax.typeable.*
import scala.util.NotGiven
import spire.math.Complex

/* Typeable instance for Array[T]
 * NOTE: It needs to iterate through the whole array to validate castability
*/
given iterableTypeable[T](using tt: Typeable[T]): Typeable[Array[T]] with
def castable(t: Any): Boolean =
t match
case (arr: Array[?]) =>
arr.forall(_.castable[T])
case _ => false
def describe = s"Array[${tt.describe}]"

/* TypeCase helpers to perform pattern matching on `Complex` higher kinded types */
val complexDoubleArray = TypeCase[Array[Complex[Double]]]
val complexFloatArray = TypeCase[Array[Complex[Float]]]

/* Type helper to describe inputs that accept Tensor or Real scalars */
type TensorOrReal[D <: RealNN] = Tensor[D] | Real

/* Evidence used in operations where Bool is accepted, but only on one of the two inputs, not both
*/
type OnlyOneBool[A <: DType, B <: DType] = NotGiven[A =:= Bool & B =:= Bool]

/* Evidence used in operations where at least one Float is required */
type AtLeastOneFloat[A <: DType, B <: DType] = A <:< FloatNN | B <:< FloatNN
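
Since `Complex[Float]` and `Complex[Double]` erase to the same runtime class, a plain pattern match cannot tell their arrays apart; the `TypeCase` values above recover the distinction by checking each element. A sketch of the intended usage (mirroring the match in `Tensor.scala` above):

```scala
import spire.math.Complex

def describe(xs: Any): String = xs match
  case complexFloatArray(arr)  => s"Array[Complex[Float]] of length ${arr.length}"
  case complexDoubleArray(arr) => s"Array[Complex[Double]] of length ${arr.length}"
  case _                       => "not a complex array"

describe(Array(Complex(1.0f, 2.0f))) // "Array[Complex[Float]] of length 1"
```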
33 changes: 23 additions & 10 deletions core/src/main/scala/torch/internal/NativeConverters.scala
@@ -18,12 +18,15 @@ package torch
package internal

import org.bytedeco.pytorch
import org.bytedeco.pytorch.ScalarTypeOptional
import org.bytedeco.pytorch.LayoutOptional
import org.bytedeco.pytorch.DeviceOptional
import org.bytedeco.pytorch.BoolOptional
import org.bytedeco.pytorch.LongOptional
import org.bytedeco.pytorch.TensorOptional
import org.bytedeco.pytorch.{
ScalarTypeOptional,
LayoutOptional,
DeviceOptional,
DoubleOptional,
BoolOptional,
LongOptional,
TensorOptional
}

import scala.reflect.Typeable
import org.bytedeco.javacpp.LongPointer
@@ -38,7 +41,18 @@ private[torch] object NativeConverters:
case i: Option[T] => i.map(f(_)).orNull
case i: T => f(i)

def toOptional(l: Long | Option[Long]): LongOptional = toOptional(l, pytorch.LongOptional(_))
def toOptional(l: Long | Option[Long]): pytorch.LongOptional =
toOptional(l, pytorch.LongOptional(_))
def toOptional(l: Double | Option[Double]): pytorch.DoubleOptional =
toOptional(l, pytorch.DoubleOptional(_))

def toOptional(l: Real | Option[Real]): pytorch.ScalarOptional =
toOptional(
l,
(r: Real) =>
val scalar = toScalar(r)
pytorch.ScalarOptional(scalar)
)

def toOptional[D <: DType](t: Tensor[D] | Option[Tensor[D]]): TensorOptional =
toOptional(t, t => pytorch.TensorOptional(t.native))
@@ -60,9 +74,8 @@ private[torch] object NativeConverters:
case x: Long => pytorch.Scalar(x)
case x: Float => pytorch.Scalar(x)
case x: Double => pytorch.Scalar(x)
case x @ Complex(r: Float, i: Float) => ???
case x @ Complex(r: Double, i: Double) => ???

case x @ Complex(r: Float, i: Float) => Tensor(Seq(x)).to(dtype = complex64).native.item()
case x @ Complex(r: Double, i: Double) => Tensor(Seq(x)).to(dtype = complex128).native.item()
def tensorOptions(
dtype: DType,
layout: Layout,
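
The new `toOptional` overloads let op wrappers forward `X | Option[X]` parameters straight to the corresponding JavaCPP optionals. A rough usage sketch (illustrative only; `NativeConverters` is private to the `torch` package):

```scala
import org.bytedeco.pytorch
import torch.internal.NativeConverters.toOptional

val empty: pytorch.DoubleOptional = toOptional(Option.empty[Double]) // empty optional
val some: pytorch.DoubleOptional  = toOptional(1.5)                  // wraps the value
```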