Skip to content

Commit

Permalink
Fix and improve initializing modules with other parameter types
Browse files Browse the repository at this point in the history
Remove broken copy and to(dtype) methods from module.
Module type conversion will need rethinking the module design so for now it's fixed on creation.
  • Loading branch information
sbrunk committed Jul 29, 2023
1 parent 5f18d5a commit 20770c4
Show file tree
Hide file tree
Showing 16 changed files with 76 additions and 57 deletions.
33 changes: 32 additions & 1 deletion core/src/main/scala/torch/DType.scala
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,38 @@ private object Derive:
val derive: Derive = Derive()
export Derive.derive

/** Default tensor type.
*
* Defaults to float32 but can be overriden by providing providing the DType explicitly, or by
* overriding the default in the current scope through an import:
*
* Example:
* ```scala sc
* import torch.*
*
* // Default:
* nn.Linear(10, 10) // Linear[Float32]
*
* // Override explicitly:
* nn.Linear[BFloat16](10, 10) // Linear[BFloat16]
*
* // Override default:
* import nn.modules.Default.float64
* nn.Linear(10, 10) // Linear[Float64]
* ```
*/
trait Default[+D <: DType]:
def dtype: D
trait LowPriorityDefaults:
given float16: Default[Float16] = new Default[Float16] { def dtype = torch.float16 }
given bfloat16: Default[BFloat16] = new Default[BFloat16] { def dtype = torch.bfloat16 }
given float64: Default[Float64] = new Default[Float64] { def dtype = torch.float64 }
given complex32: Default[Complex32] = new Default[Complex32] { def dtype = torch.complex32 }
given complex64: Default[Complex64] = new Default[Complex64] { def dtype = torch.complex64 }
given complex128: Default[Complex128] = new Default[Complex128] { def dtype = torch.complex128 }
object Default extends LowPriorityDefaults:
given float32: Default[Float32] = new Default[Float32] { def dtype = torch.float32 }

/** DType combinations * */
type FloatNN = Float16 | Float32 | Float64 | BFloat16

Expand Down Expand Up @@ -452,4 +484,3 @@ transparent inline def deriveDType[T <: DType]: DType =
case _: Float16 => float16
case _: Undefined => undefined
case _: NumOptions => numoptions
case _ => float32
25 changes: 1 addition & 24 deletions core/src/main/scala/torch/nn/modules/Module.scala
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,6 @@ abstract class Module {
def namedModules: SeqMap[String, Module] =
namedChildren.flatMap((name, module) => module.namedModules)

def copy(): this.type =
val clone = super.clone().asInstanceOf[Module]
clone._nativeModule = _nativeModule.clone(null)
clone.asInstanceOf[this.type]

def register[M <: Module](child: M)(using name: sourcecode.Name) =
// println(s"registering ${name.value}:$child")
childModules = childModules.updated(name.value, child)
Expand All @@ -94,16 +89,7 @@ abstract class Module {
def train(on: Boolean = true): Unit = nativeModule.train(on)

def to(device: Device): this.type =
// val nativeCopy = nativeModule.clone()
nativeModule.to(device.toNative, false)
// copy
// val clone: this.type = copy()
// clone.nativeModule = nativeCopy
this

def to(dtype: DType, nonBlocking: Boolean = false): this.type =
val nativeCopy = nativeModule.clone(null)
nativeCopy.asModule.to(dtype.toScalarType, false)
this

def save(outputArchive: OutputArchive) = nativeModule.save(outputArchive)
Expand All @@ -123,20 +109,11 @@ abstract class Module {
doSummarize(0)
}

/** Default tensor type for module parameters.
*
* Defaults to float32 but can be overriden by providing a given
*/
trait Default[+D <: DType]:
def dtype: D
object Default:
given f32: Default[Float32] = new Default[Float32] { def dtype = float32 }

trait HasParams[ParamType <: FloatNN | ComplexNN: Default] extends Module:
override def parameters(recurse: Boolean): Seq[Tensor[ParamType]] =
nativeModule.parameters(recurse).get().toSeq.map(Tensor.apply[ParamType])
override def parameters: Seq[Tensor[ParamType]] = parameters(recurse = true)
transparent inline def paramType = deriveDType[ParamType]
transparent inline def paramType: DType = summon[Default[ParamType]].dtype

trait HasWeight[ParamType <: FloatNN | ComplexNN]:
def weight: Tensor[ParamType]
Expand Down
9 changes: 6 additions & 3 deletions core/src/main/scala/torch/nn/modules/activation/Softmax.scala
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@
* limitations under the License.
*/

package torch.nn.modules.activation
package torch
package nn
package modules
package activation

import org.bytedeco.pytorch
import org.bytedeco.pytorch.SoftmaxImpl
Expand All @@ -28,7 +31,7 @@ import torch.{DType, Tensor}
*
* When the input Tensor is a sparse tensor then the unspecifed values are treated as ``-inf``.
*/
final class Softmax(dim: Int) extends Module:
final class Softmax[D <: DType: Default](dim: Int) extends TensorModule[D]:
override val nativeModule: SoftmaxImpl = SoftmaxImpl(dim)

def apply[D <: DType](t: Tensor[D]): Tensor[D] = Tensor(nativeModule.forward(t.native))
def apply(t: Tensor[D]): Tensor[D] = Tensor(nativeModule.forward(t.native))
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import org.bytedeco.pytorch
import sourcecode.Name
import org.bytedeco.pytorch.BatchNorm2dImpl
import org.bytedeco.pytorch.BatchNormOptions
import torch.nn.modules.{HasParams, HasWeight, TensorModule}
import torch.nn.modules.{HasParams, HasWeight}


// format: off
Expand Down
2 changes: 1 addition & 1 deletion core/src/main/scala/torch/nn/modules/conv/Conv2d.scala
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import sourcecode.Name
import torch.Tensor
import torch.internal.NativeConverters.toNative
import torch.nn.modules.conv.Conv2d.PaddingMode
import torch.nn.modules.{HasParams, TensorModule}
import torch.nn.modules.{HasParams}

/** Applies a 2D convolution over an input signal composed of several input planes.
*
Expand Down
4 changes: 2 additions & 2 deletions core/src/main/scala/torch/nn/modules/linear/Identity.scala
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import torch.{DType, Tensor}
*
* @group nn_linear
*/
final class Identity(args: Any*) extends Module:
final class Identity[D <: DType: Default](args: Any*) extends TensorModule[D]:
override val nativeModule: IdentityImpl = IdentityImpl()

def apply[D <: DType](t: Tensor[D]): Tensor[D] = Tensor(nativeModule.forward(t.native))
def apply(t: Tensor[D]): Tensor[D] = Tensor(nativeModule.forward(t.native))
4 changes: 2 additions & 2 deletions core/src/main/scala/torch/nn/modules/linear/Linear.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ package linear
import org.bytedeco.pytorch
import org.bytedeco.pytorch.{LinearImpl, LinearOptions}
import torch.Tensor
import torch.nn.modules.{HasParams, TensorModule}
import torch.nn.modules.{HasParams}

/** Applies a linear transformation to the incoming data: $y = xA^T + b$
*
Expand Down Expand Up @@ -53,7 +53,7 @@ final class Linear[ParamType <: FloatNN: Default](
bias: Boolean = true
// dtype: ParamType = defaultDType[ParamType]
) extends HasParams[ParamType]
with (TensorModule[ParamType]):
with TensorModule[ParamType]:
private val options = new LinearOptions(inFeatures, outFeatures)
options.bias().put(bias)
override private[torch] val nativeModule: LinearImpl = new LinearImpl(options)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ package normalization

import org.bytedeco.pytorch
import org.bytedeco.pytorch.{GroupNormImpl, GroupNormOptions}
import torch.nn.modules.TensorModule
import torch.{DType, Tensor}

/** Applies Group Normalization over a mini-batch of inputs
Expand All @@ -36,12 +35,13 @@ import torch.{DType, Tensor}
* a boolean value that when set to `true`, this module has learnable per-channel affine
* parameters initialized to ones (for weights) and zeros (for biases)
*/
final class GroupNorm[ParamType <: DType](
final class GroupNorm[ParamType <: FloatNN | ComplexNN: Default](
numGroups: Int,
numChannels: Int,
eps: Double = 1e-05,
affine: Boolean = true
) extends TensorModule[ParamType]:
) extends HasWeight[ParamType]
with TensorModule[ParamType]:
private val options: GroupNormOptions = GroupNormOptions(numGroups, numChannels)
options.eps().put(eps)
options.affine().put(affine)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ import org.bytedeco.pytorch.LongOptional
* The output is of size H x W, for any input size. The number of output features is equal to the
* number of input planes.
*/
final class AdaptiveAvgPool2d(
final class AdaptiveAvgPool2d[D <: BFloat16 | Float32 | Float64: Default](
outputSize: Int | Option[Int] | (Option[Int], Option[Int]) | (Int, Int)
) extends Module {

Expand All @@ -49,7 +49,7 @@ final class AdaptiveAvgPool2d(
nativeOutputSize.get(0)
)

def apply[D <: BFloat16 | Float32 | Float64](t: Tensor[D]): Tensor[D] = Tensor(
def apply(t: Tensor[D]): Tensor[D] = Tensor(
nativeModule.forward(t.native)
)
}
10 changes: 4 additions & 6 deletions core/src/main/scala/torch/nn/modules/pooling/MaxPool2d.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,19 +23,18 @@ import org.bytedeco.javacpp.LongPointer
import org.bytedeco.pytorch
import org.bytedeco.pytorch.{MaxPool2dImpl, MaxPool2dOptions}
import torch.internal.NativeConverters.toNative
import torch.nn.modules.{HasParams, TensorModule}
import torch.nn.modules.{HasParams}
import torch.{BFloat16, Float32, Float64, Tensor}

/** Applies a 2D max pooling over an input signal composed of several input planes. */
final class MaxPool2d[ParamType <: BFloat16 | Float32 | Float64: Default](
final class MaxPool2d[D <: BFloat16 | Float32 | Float64: Default](
kernelSize: Int | (Int, Int),
stride: Option[Int | (Int, Int)] = None,
padding: Int | (Int, Int) = 0,
dilation: Int | (Int, Int) = 1,
// returnIndices: Boolean = false,
ceilMode: Boolean = false
) extends HasParams[ParamType]
with TensorModule[ParamType]:
) extends TensorModule[D]:

private val options: MaxPool2dOptions = MaxPool2dOptions(toNative(kernelSize))
stride.foreach(s => options.stride().put(toNative(s)))
Expand All @@ -44,10 +43,9 @@ final class MaxPool2d[ParamType <: BFloat16 | Float32 | Float64: Default](
options.ceil_mode().put(ceilMode)

override private[torch] val nativeModule: MaxPool2dImpl = MaxPool2dImpl(options)
nativeModule.to(paramType.toScalarType, false)

override def toString(): String =
s"MaxPool2d(kernelSize=$kernelSize, stride=$stride, padding=$padding, dilation=$dilation, ceilMode=$ceilMode)"

def apply(t: Tensor[ParamType]): Tensor[ParamType] = Tensor(nativeModule.forward(t.native))
def apply(t: Tensor[D]): Tensor[D] = Tensor(nativeModule.forward(t.native))
// TODO forward_with_indices
1 change: 0 additions & 1 deletion core/src/main/scala/torch/nn/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ package torch
package object nn {

export modules.Module
export modules.Default

export modules.activation.Softmax
export modules.activation.ReLU
Expand Down
2 changes: 1 addition & 1 deletion docs/modules.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import torch.*
import torch.nn
import torch.nn.functional as F

class LeNet[D <: BFloat16 | Float32: nn.Default] extends nn.Module:
class LeNet[D <: BFloat16 | Float32: Default] extends nn.Module:
val conv1 = register(nn.Conv2d(1, 6, 5))
val pool = register(nn.MaxPool2d((2, 2)))
val conv2 = register(nn.Conv2d(6, 16, 5))
Expand Down
2 changes: 1 addition & 1 deletion docs/tutorial/buildmodel.md
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ We get the prediction probabilities by passing it through an instance of the ``n
```scala mdoc
val X = torch.rand(Seq(1, 28, 28), device=device)
val logits = model(X)
val predProbab = nn.Softmax(dim=1)(logits)
val predProbab = nn.Softmax(dim=1).apply(logits)
val yPred = predProbab.argmax(1)
println(s"Predicted class: $yPred")
```
Expand Down
3 changes: 1 addition & 2 deletions examples/src/main/scala/ImageClassifier.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,13 @@

//> using scala "3.3"
//> using repository "sonatype-s01:snapshots"
//> using lib "dev.storch::vision:0.0-ab8d84c-SNAPSHOT"
//> using lib "dev.storch::vision:0.0-795485b-SNAPSHOT"
//> using lib "me.tongfei:progressbar:0.9.5"
//> using lib "com.github.alexarchambault::case-app:2.1.0-M24"
//> using lib "org.scala-lang.modules::scala-parallel-collections:1.0.4"
// replace with pytorch-platform-gpu if you have a CUDA capable GPU
//> using lib "org.bytedeco:pytorch-platform:2.0.1-1.5.9"
// enable for CUDA support
////> using lib "org.bytedeco:cuda-platform:12.1-8.9-1.5.9"
////> using lib "org.bytedeco:cuda-platform-redist:12.1-8.9-1.5.9"

import Commands.*
Expand Down
10 changes: 5 additions & 5 deletions examples/src/main/scala/LeNet.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,28 +16,28 @@

//> using scala "3.3"
//> using repository "sonatype-s01:snapshots"
//> using lib "dev.storch::vision:0.0-ab8d84c-SNAPSHOT"
//> using lib "dev.storch::vision:0.0-795485b-SNAPSHOT"
// replace with pytorch-platform-gpu if you have a CUDA capable GPU
//> using lib "org.bytedeco:pytorch-platform:2.0.1-1.5.9"
// enable for CUDA support
////> using lib "org.bytedeco:cuda-platform:12.1-8.9-1.5.9"
////> using lib "org.bytedeco:cuda-platform-redist:12.1-8.9-1.5.9"

import torch.*
import torch.nn.functional as F
import torch.optim.Adam
import org.bytedeco.pytorch.OutputArchive
import torch.nn.modules.Default
import torchvision.datasets.MNIST
import scala.util.Random
import java.nio.file.Paths
import torch.Device.CUDA
import scala.util.Using
import org.bytedeco.javacpp.PointerScope
import torch.Device.CPU
import torch.nn.modules.HasParams

// Define the model architecture
class LeNet[D <: BFloat16 | Float32: Default] extends HasParams[D] {

// define model architecture
class LeNet[D <: BFloat16 | Float32: Default] extends nn.Module {
val conv1 = register(nn.Conv2d(1, 6, 5))
val pool = register(nn.MaxPool2d((2, 2)))
val conv2 = register(nn.Conv2d(6, 16, 5))
Expand Down
16 changes: 14 additions & 2 deletions vision/src/main/scala/torchvision/models/resnet.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,18 @@
package torchvision
package models

import torch.{BFloat16, ComplexNN, DType, Float32, Float32Tensor, Float64, FloatNN, Tensor, nn}
import torch.{
BFloat16,
ComplexNN,
DType,
Default,
Float32,
Float32Tensor,
Float64,
FloatNN,
Tensor,
nn
}
import torch.nn.init.{Mode, NonLinearity, constant_, kaimingNormal_}

import scala.collection.mutable
Expand All @@ -28,12 +39,13 @@ import sourcecode.Name
import torch.nn.modules.activation.ReLU
import torch.nn.modules.conv.Conv2d
import torch.nn.modules.pooling.{AdaptiveAvgPool2d, MaxPool2d}
import torch.nn.modules.{Default, HasWeight, Module, TensorModule}
import torch.nn.modules.{HasWeight, Module}
import torchvision.transforms.*

import scala.util.Using
import com.sksamuel.scrimage.ImmutableImage
import torch.Int32
import torch.nn.modules.TensorModule

/** ResNet architecture implementations
*
Expand Down

0 comments on commit 20770c4

Please sign in to comment.