Fix module initialization with other dtypes and simplify module registration #47

Merged: 2 commits, Jul 30, 2023
Changes from all commits
33 changes: 32 additions & 1 deletion core/src/main/scala/torch/DType.scala
@@ -209,6 +209,38 @@ private object Derive:
val derive: Derive = Derive()
export Derive.derive

/** Default tensor type.
*
 * Defaults to float32 but can be overridden by providing the DType explicitly, or by
* overriding the default in the current scope through an import:
*
* Example:
* ```scala sc
* import torch.*
*
* // Default:
* nn.Linear(10, 10) // Linear[Float32]
*
* // Override explicitly:
* nn.Linear[BFloat16](10, 10) // Linear[BFloat16]
*
* // Override default:
* import Default.float64
* nn.Linear(10, 10) // Linear[Float64]
* ```
*/
trait Default[+D <: DType]:
def dtype: D
trait LowPriorityDefaults:
given float16: Default[Float16] = new Default[Float16] { def dtype = torch.float16 }
given bfloat16: Default[BFloat16] = new Default[BFloat16] { def dtype = torch.bfloat16 }
given float64: Default[Float64] = new Default[Float64] { def dtype = torch.float64 }
given complex32: Default[Complex32] = new Default[Complex32] { def dtype = torch.complex32 }
given complex64: Default[Complex64] = new Default[Complex64] { def dtype = torch.complex64 }
given complex128: Default[Complex128] = new Default[Complex128] { def dtype = torch.complex128 }
object Default extends LowPriorityDefaults:
given float32: Default[Float32] = new Default[Float32] { def dtype = torch.float32 }

/** DType combinations * */
type FloatNN = Float16 | Float32 | Float64 | BFloat16

@@ -452,4 +484,3 @@ transparent inline def deriveDType[T <: DType]: DType =
case _: Float16 => float16
case _: Undefined => undefined
case _: NumOptions => numoptions
case _ => float32
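The `Default` type class is what lets modules pick a parameter dtype without an explicit type argument: `float32` lives in the companion object and therefore wins by default, while the other givens sit in `LowPriorityDefaults` and only apply when imported or requested explicitly. A minimal sketch of summoning the runtime `DType` through it (the `dtypeOf` helper is illustrative, not part of this PR):

```scala
import torch.*

// Resolve the runtime DType for a statically known dtype via the Default
// type class, mirroring what HasParams.paramType does after this change.
def dtypeOf[D <: DType](using d: Default[D]): DType = d.dtype

dtypeOf[Float32]  // torch.float32 (given in the Default companion object)
dtypeOf[BFloat16] // torch.bfloat16 (low-priority given from LowPriorityDefaults)
```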
34 changes: 3 additions & 31 deletions core/src/main/scala/torch/nn/modules/Module.scala
@@ -70,20 +70,10 @@ abstract class Module {
def namedModules: SeqMap[String, Module] =
namedChildren.flatMap((name, module) => module.namedModules)

def copy(): this.type =
val clone = super.clone().asInstanceOf[Module]
clone._nativeModule = _nativeModule.clone(null)
clone.asInstanceOf[this.type]

protected[torch] def registerWithParent[T <: pytorch.Module](parent: T)(using
name: sourcecode.Name
): Unit =
parent.register_module(name.value, nativeModule)

def register[M <: Module](child: M)(using name: sourcecode.Name) =
// println(s"registering ${name.value}:$child")
childModules = childModules.updated(name.value, child)
child.registerWithParent(this.nativeModule)
nativeModule.register_module(name.value, child.nativeModule)
child

def register[D <: DType](t: Tensor[D], requiresGrad: Boolean = true)(using
@@ -99,16 +89,7 @@
def train(on: Boolean = true): Unit = nativeModule.train(on)

def to(device: Device): this.type =
// val nativeCopy = nativeModule.clone()
nativeModule.asModule.to(device.toNative, false)
// copy
// val clone: this.type = copy()
// clone.nativeModule = nativeCopy
this

def to(dtype: DType, nonBlocking: Boolean = false): this.type =
val nativeCopy = nativeModule.clone(null)
nativeCopy.asModule.to(dtype.toScalarType, false)
nativeModule.to(device.toNative, false)
this

def save(outputArchive: OutputArchive) = nativeModule.save(outputArchive)
@@ -128,20 +109,11 @@
doSummarize(0)
}

/** Default tensor type for module parameters.
*
 * Defaults to float32 but can be overridden by providing a given
*/
trait Default[+D <: DType]:
def dtype: D
object Default:
given f32: Default[Float32] = new Default[Float32] { def dtype = float32 }

trait HasParams[ParamType <: FloatNN | ComplexNN: Default] extends Module:
override def parameters(recurse: Boolean): Seq[Tensor[ParamType]] =
nativeModule.parameters(recurse).get().toSeq.map(Tensor.apply[ParamType])
override def parameters: Seq[Tensor[ParamType]] = parameters(recurse = true)
transparent inline def paramType = deriveDType[ParamType]
transparent inline def paramType: DType = summon[Default[ParamType]].dtype

trait HasWeight[ParamType <: FloatNN | ComplexNN]:
def weight: Tensor[ParamType]
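With `registerWithParent` gone, wiring a child module into its parent now happens entirely inside `register`. A hedged sketch of the resulting user-facing pattern (the module name, sizes, and forward logic are illustrative):

```scala
import torch.*

// `register` records each child under its field name (via sourcecode.Name)
// and registers it with the parent's native module in one step.
class Block[D <: FloatNN: Default] extends nn.Module:
  val proj = register(nn.Linear[D](16, 16))
  val act  = register(nn.ReLU[D]())
  def apply(x: Tensor[D]): Tensor[D] = act(proj(x))

val block = Block()                       // D resolves to Float32 via Default
val out   = block(torch.rand(Seq(2, 16))) // Tensor[Float32]
```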
5 changes: 0 additions & 5 deletions core/src/main/scala/torch/nn/modules/activation/ReLU.scala
@@ -34,11 +34,6 @@ final class ReLU[D <: DType: Default](inplace: Boolean = false) extends TensorMo

override protected[torch] val nativeModule: ReLUImpl = ReLUImpl()

override def registerWithParent[M <: pytorch.Module](parent: M)(using
name: sourcecode.Name
): Unit =
parent.register_module(name.value, nativeModule)

def apply(t: Tensor[D]): Tensor[D] = Tensor(nativeModule.forward(t.native))

override def toString = getClass().getSimpleName()
14 changes: 6 additions & 8 deletions core/src/main/scala/torch/nn/modules/activation/Softmax.scala
@@ -14,7 +14,10 @@
* limitations under the License.
*/

package torch.nn.modules.activation
package torch
package nn
package modules
package activation

import org.bytedeco.pytorch
import org.bytedeco.pytorch.SoftmaxImpl
@@ -28,12 +31,7 @@ import torch.{DType, Tensor}
*
 * When the input Tensor is a sparse tensor then the unspecified values are treated as ``-inf``.
*/
final class Softmax(dim: Int) extends Module:
final class Softmax[D <: DType: Default](dim: Int) extends TensorModule[D]:
override val nativeModule: SoftmaxImpl = SoftmaxImpl(dim)

override def registerWithParent[M <: pytorch.Module](parent: M)(using
name: sourcecode.Name
): Unit =
parent.register_module(name.value, nativeModule)

def apply[D <: DType](t: Tensor[D]): Tensor[D] = Tensor(nativeModule.forward(t.native))
def apply(t: Tensor[D]): Tensor[D] = Tensor(nativeModule.forward(t.native))
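Since `Softmax` is now a `TensorModule[D]`, its dtype is fixed at construction and `apply` only accepts matching tensors. A usage sketch (shapes are illustrative):

```scala
import torch.*

val logits  = torch.rand(Seq(2, 10))      // Tensor[Float32]
val softmax = nn.Softmax[Float32](dim = 1)
val probs   = softmax(logits)             // Tensor[Float32]; each row sums to ~1
```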
core/src/main/scala/torch/nn/modules/batchnorm/BatchNorm2d.scala
@@ -24,7 +24,7 @@ import org.bytedeco.pytorch
import sourcecode.Name
import org.bytedeco.pytorch.BatchNorm2dImpl
import org.bytedeco.pytorch.BatchNormOptions
import torch.nn.modules.{HasParams, HasWeight, TensorModule}
import torch.nn.modules.{HasParams, HasWeight}


// format: off
@@ -119,12 +119,7 @@ final class BatchNorm2d[ParamType <: FloatNN | ComplexNN: Default](
options.track_running_stats().put(trackRunningStats)

override private[torch] val nativeModule: BatchNorm2dImpl = BatchNorm2dImpl(options)
nativeModule.asModule.to(paramType.toScalarType, false)

override def registerWithParent[M <: pytorch.Module](parent: M)(using
name: sourcecode.Name
): Unit =
parent.register_module(name.value, nativeModule)
nativeModule.to(paramType.toScalarType, false)

// TODO weight, bias etc. are undefined if affine = false. We need to take that into account
val weight: Tensor[ParamType] = Tensor[ParamType](nativeModule.weight)
9 changes: 2 additions & 7 deletions core/src/main/scala/torch/nn/modules/conv/Conv2d.scala
@@ -26,7 +26,7 @@ import sourcecode.Name
import torch.Tensor
import torch.internal.NativeConverters.toNative
import torch.nn.modules.conv.Conv2d.PaddingMode
import torch.nn.modules.{HasParams, TensorModule}
import torch.nn.modules.{HasParams}

/** Applies a 2D convolution over an input signal composed of several input planes.
*
@@ -59,12 +59,7 @@ final class Conv2d[ParamType <: FloatNN | ComplexNN: Default](
options.padding_mode().put(paddingModeNative)

override private[torch] val nativeModule: Conv2dImpl = Conv2dImpl(options)
nativeModule.asModule.to(paramType.toScalarType, false)

override def registerWithParent[M <: pytorch.Module](parent: M)(using
name: sourcecode.Name
): Unit =
parent.register_module(name.value, nativeModule)
nativeModule.to(paramType.toScalarType, false)

def apply(t: Tensor[ParamType]): Tensor[ParamType] = Tensor(nativeModule.forward(t.native))

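The key change for `Conv2d` (and the other parameterized modules below) is that the freshly constructed native module is converted to the requested parameter dtype via `nativeModule.to` rather than going through `asModule`, which is what makes construction with non-default dtypes work. A short sketch reusing the channel and kernel sizes from the LeNet example in docs/modules.md:

```scala
import torch.*

// 1 input channel, 6 output channels, 5x5 kernel; the parameters are
// converted to bfloat16 as soon as the native module is created.
val conv = nn.Conv2d[BFloat16](1, 6, 5)
```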
5 changes: 0 additions & 5 deletions core/src/main/scala/torch/nn/modules/flatten/Flatten.scala
@@ -59,11 +59,6 @@ final class Flatten[D <: DType: Default](startDim: Int = 1, endDim: Int = -1)

override val nativeModule: FlattenImpl = FlattenImpl(options)

override def registerWithParent[T <: pytorch.Module](parent: T)(using
name: sourcecode.Name
): Unit =
parent.register_module(name.value, nativeModule)

def apply(t: Tensor[D]): Tensor[D] = Tensor(nativeModule.forward(t.native))

override def toString = getClass().getSimpleName()
9 changes: 2 additions & 7 deletions core/src/main/scala/torch/nn/modules/linear/Identity.scala
@@ -28,12 +28,7 @@ import torch.{DType, Tensor}
*
* @group nn_linear
*/
final class Identity(args: Any*) extends Module:
final class Identity[D <: DType: Default](args: Any*) extends TensorModule[D]:
override val nativeModule: IdentityImpl = IdentityImpl()

override def registerWithParent[M <: pytorch.Module](parent: M)(using
name: sourcecode.Name
): Unit =
parent.register_module(name.value, nativeModule)

def apply[D <: DType](t: Tensor[D]): Tensor[D] = Tensor(nativeModule.forward(t.native))
def apply(t: Tensor[D]): Tensor[D] = Tensor(nativeModule.forward(t.native))
11 changes: 3 additions & 8 deletions core/src/main/scala/torch/nn/modules/linear/Linear.scala
@@ -22,7 +22,7 @@ package linear
import org.bytedeco.pytorch
import org.bytedeco.pytorch.{LinearImpl, LinearOptions}
import torch.Tensor
import torch.nn.modules.{HasParams, TensorModule}
import torch.nn.modules.{HasParams}

/** Applies a linear transformation to the incoming data: $y = xA^T + b$
*
@@ -53,16 +53,11 @@ final class Linear[ParamType <: FloatNN: Default](
bias: Boolean = true
// dtype: ParamType = defaultDType[ParamType]
) extends HasParams[ParamType]
with (TensorModule[ParamType]):
with TensorModule[ParamType]:
private val options = new LinearOptions(inFeatures, outFeatures)
options.bias().put(bias)
override private[torch] val nativeModule: LinearImpl = new LinearImpl(options)
nativeModule.asModule.to(paramType.toScalarType, false)

override def registerWithParent[T <: pytorch.Module](parent: T)(using
name: sourcecode.Name
): Unit =
parent.register_module(name.value, nativeModule)
nativeModule.to(paramType.toScalarType, false)

def apply(input: Tensor[ParamType]): Tensor[ParamType] = Tensor(
nativeModule.forward(input.native)
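The same fix applies to `Linear`: its parameters now actually end up in the requested dtype. A hedged end-to-end sketch (the `dtype` parameter of `torch.rand` is assumed from the storch API; sizes are illustrative):

```scala
import torch.*

val fc = nn.Linear[Float64](10, 10)                     // parameters held as float64
val x  = torch.rand(Seq(1, 10), dtype = torch.float64)  // matching input dtype
val y  = fc(x)                                          // Tensor[Float64]
```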
core/src/main/scala/torch/nn/modules/normalization/GroupNorm.scala
@@ -21,7 +21,6 @@ package normalization

import org.bytedeco.pytorch
import org.bytedeco.pytorch.{GroupNormImpl, GroupNormOptions}
import torch.nn.modules.TensorModule
import torch.{DType, Tensor}

/** Applies Group Normalization over a mini-batch of inputs
@@ -36,23 +35,19 @@ import torch.{DType, Tensor}
* a boolean value that when set to `true`, this module has learnable per-channel affine
* parameters initialized to ones (for weights) and zeros (for biases)
*/
final class GroupNorm[ParamType <: DType](
final class GroupNorm[ParamType <: FloatNN | ComplexNN: Default](
numGroups: Int,
numChannels: Int,
eps: Double = 1e-05,
affine: Boolean = true
) extends TensorModule[ParamType]:
) extends HasWeight[ParamType]
with TensorModule[ParamType]:
private val options: GroupNormOptions = GroupNormOptions(numGroups, numChannels)
options.eps().put(eps)
options.affine().put(affine)

override private[torch] val nativeModule: GroupNormImpl = GroupNormImpl(options)

override def registerWithParent[M <: pytorch.Module](parent: M)(using
name: sourcecode.Name
): Unit =
parent.register_module(name.value, nativeModule)

val weight: Tensor[ParamType] = Tensor[ParamType](nativeModule.weight)
val bias: Tensor[ParamType] = Tensor[ParamType](nativeModule.bias)

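`GroupNorm` now takes a `Default` parameter type like the other parameterized modules and exposes its learnable scale through `HasWeight`. A usage sketch, assuming `GroupNorm` is exported under `torch.nn` like the modules above (channel counts are illustrative):

```scala
import torch.*

val gn = nn.GroupNorm[Float32](numGroups = 3, numChannels = 6)
val x  = torch.rand(Seq(1, 6, 4, 4)) // (batch, channels, height, width)
val y  = gn(x)                       // Tensor[Float32], normalized per group
gn.weight                            // Tensor[Float32], initialized to ones when affine = true
```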
core/src/main/scala/torch/nn/modules/pooling/AdaptiveAvgPool2d.scala
@@ -32,7 +32,7 @@ import org.bytedeco.pytorch.LongOptional
* The output is of size H x W, for any input size. The number of output features is equal to the
* number of input planes.
*/
final class AdaptiveAvgPool2d(
final class AdaptiveAvgPool2d[D <: BFloat16 | Float32 | Float64: Default](
outputSize: Int | Option[Int] | (Option[Int], Option[Int]) | (Int, Int)
) extends Module {

@@ -45,16 +45,11 @@ final class AdaptiveAvgPool2d(
case x: Option[Int] =>
new LongOptionalVector(x.toOptional, x.toOptional)

override private[torch] val nativeModule: AdaptiveAvgPool2dImpl = AdaptiveAvgPool2dImpl(
override protected[torch] val nativeModule: AdaptiveAvgPool2dImpl = AdaptiveAvgPool2dImpl(
nativeOutputSize.get(0)
)

override def registerWithParent[T <: pytorch.Module](parent: T)(using
name: sourcecode.Name
): Unit =
parent.register_module(name.value, nativeModule)

def apply[D <: BFloat16 | Float32 | Float64](t: Tensor[D]): Tensor[D] = Tensor(
def apply(t: Tensor[D]): Tensor[D] = Tensor(
nativeModule.forward(t.native)
)
}
15 changes: 4 additions & 11 deletions core/src/main/scala/torch/nn/modules/pooling/MaxPool2d.scala
@@ -23,19 +23,18 @@ import org.bytedeco.javacpp.LongPointer
import org.bytedeco.pytorch
import org.bytedeco.pytorch.{MaxPool2dImpl, MaxPool2dOptions}
import torch.internal.NativeConverters.toNative
import torch.nn.modules.{HasParams, TensorModule}
import torch.nn.modules.{HasParams}
import torch.{BFloat16, Float32, Float64, Tensor}

/** Applies a 2D max pooling over an input signal composed of several input planes. */
final class MaxPool2d[ParamType <: BFloat16 | Float32 | Float64: Default](
final class MaxPool2d[D <: BFloat16 | Float32 | Float64: Default](
kernelSize: Int | (Int, Int),
stride: Option[Int | (Int, Int)] = None,
padding: Int | (Int, Int) = 0,
dilation: Int | (Int, Int) = 1,
// returnIndices: Boolean = false,
ceilMode: Boolean = false
) extends HasParams[ParamType]
with TensorModule[ParamType]:
) extends TensorModule[D]:

private val options: MaxPool2dOptions = MaxPool2dOptions(toNative(kernelSize))
stride.foreach(s => options.stride().put(toNative(s)))
@@ -44,15 +43,9 @@ final class MaxPool2d[ParamType <: BFloat16 | Float32 | Float64: Default](
options.ceil_mode().put(ceilMode)

override private[torch] val nativeModule: MaxPool2dImpl = MaxPool2dImpl(options)
nativeModule.asModule.to(paramType.toScalarType, false)

override def registerWithParent[M <: pytorch.Module](parent: M)(using
name: sourcecode.Name
): Unit =
parent.register_module(name.value, nativeModule)

override def toString(): String =
s"MaxPool2d(kernelSize=$kernelSize, stride=$stride, padding=$padding, dilation=$dilation, ceilMode=$ceilMode)"

def apply(t: Tensor[ParamType]): Tensor[ParamType] = Tensor(nativeModule.forward(t.native))
def apply(t: Tensor[D]): Tensor[D] = Tensor(nativeModule.forward(t.native))
// TODO forward_with_indices
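`MaxPool2d` has no learnable parameters, so it drops `HasParams` and keeps only the dtype bound it needs for its input tensors. A usage sketch with a typical NCHW input (the shape is illustrative):

```scala
import torch.*

val pool = nn.MaxPool2d[Float32]((2, 2))
val img  = torch.rand(Seq(1, 1, 28, 28))
val out  = pool(img) // Tensor[Float32] with spatial size halved to 14x14
```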
1 change: 0 additions & 1 deletion core/src/main/scala/torch/nn/package.scala
@@ -25,7 +25,6 @@ package torch
package object nn {

export modules.Module
export modules.Default

export modules.activation.Softmax
export modules.activation.ReLU
2 changes: 1 addition & 1 deletion docs/modules.md
@@ -13,7 +13,7 @@ import torch.*
import torch.nn
import torch.nn.functional as F

class LeNet[D <: BFloat16 | Float32: nn.Default] extends nn.Module:
class LeNet[D <: BFloat16 | Float32: Default] extends nn.Module:
val conv1 = register(nn.Conv2d(1, 6, 5))
val pool = register(nn.MaxPool2d((2, 2)))
val conv2 = register(nn.Conv2d(6, 16, 5))
2 changes: 1 addition & 1 deletion docs/tutorial/buildmodel.md
@@ -78,7 +78,7 @@ We get the prediction probabilities by passing it through an instance of the ``n
```scala mdoc
val X = torch.rand(Seq(1, 28, 28), device=device)
val logits = model(X)
val predProbab = nn.Softmax(dim=1)(logits)
val predProbab = nn.Softmax(dim=1).apply(logits)
val yPred = predProbab.argmax(1)
println(s"Predicted class: $yPred")
```
3 changes: 1 addition & 2 deletions examples/src/main/scala/ImageClassifier.scala
@@ -16,14 +16,13 @@

//> using scala "3.3"
//> using repository "sonatype-s01:snapshots"
//> using lib "dev.storch::vision:0.0-ab8d84c-SNAPSHOT"
//> using lib "dev.storch::vision:0.0-795485b-SNAPSHOT"
//> using lib "me.tongfei:progressbar:0.9.5"
//> using lib "com.github.alexarchambault::case-app:2.1.0-M24"
//> using lib "org.scala-lang.modules::scala-parallel-collections:1.0.4"
// replace with pytorch-platform-gpu if you have a CUDA capable GPU
//> using lib "org.bytedeco:pytorch-platform:2.0.1-1.5.9"
// enable for CUDA support
////> using lib "org.bytedeco:cuda-platform:12.1-8.9-1.5.9"
////> using lib "org.bytedeco:cuda-platform-redist:12.1-8.9-1.5.9"

import Commands.*
Expand Down
Loading