From 20770c47e344f34c0aa575b29ac3cfee6082d2de Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=B6ren=20Brunk?= Date: Sun, 30 Jul 2023 01:23:56 +0200 Subject: [PATCH] Fix and improve initializing modules with other parameter types Remove broken copy and to(dtype) methods from module. Module type conversion will need rethinking the module design so for now it's fixed on creation. --- core/src/main/scala/torch/DType.scala | 33 ++++++++++++++++++- .../main/scala/torch/nn/modules/Module.scala | 25 +------------- .../torch/nn/modules/activation/Softmax.scala | 9 +++-- .../nn/modules/batchnorm/BatchNorm2d.scala | 2 +- .../scala/torch/nn/modules/conv/Conv2d.scala | 2 +- .../torch/nn/modules/linear/Identity.scala | 4 +-- .../torch/nn/modules/linear/Linear.scala | 4 +-- .../nn/modules/normalization/GroupNorm.scala | 6 ++-- .../modules/pooling/AdaptiveAvgPool2d.scala | 4 +-- .../torch/nn/modules/pooling/MaxPool2d.scala | 10 +++--- core/src/main/scala/torch/nn/package.scala | 1 - docs/modules.md | 2 +- docs/tutorial/buildmodel.md | 2 +- examples/src/main/scala/ImageClassifier.scala | 3 +- examples/src/main/scala/LeNet.scala | 10 +++--- .../scala/torchvision/models/resnet.scala | 16 +++++++-- 16 files changed, 76 insertions(+), 57 deletions(-) diff --git a/core/src/main/scala/torch/DType.scala b/core/src/main/scala/torch/DType.scala index 18b33c6d..06e835b9 100644 --- a/core/src/main/scala/torch/DType.scala +++ b/core/src/main/scala/torch/DType.scala @@ -209,6 +209,38 @@ private object Derive: val derive: Derive = Derive() export Derive.derive +/** Default tensor type. + * + * Defaults to float32 but can be overriden by providing providing the DType explicitly, or by + * overriding the default in the current scope through an import: + * + * Example: + * ```scala sc + * import torch.* + * + * // Default: + * nn.Linear(10, 10) // Linear[Float32] + * + * // Override explicitly: + * nn.Linear[BFloat16](10, 10) // Linear[BFloat16] + * + * // Override default: + * import nn.modules.Default.float64 + * nn.Linear(10, 10) // Linear[Float64] + * ``` + */ +trait Default[+D <: DType]: + def dtype: D +trait LowPriorityDefaults: + given float16: Default[Float16] = new Default[Float16] { def dtype = torch.float16 } + given bfloat16: Default[BFloat16] = new Default[BFloat16] { def dtype = torch.bfloat16 } + given float64: Default[Float64] = new Default[Float64] { def dtype = torch.float64 } + given complex32: Default[Complex32] = new Default[Complex32] { def dtype = torch.complex32 } + given complex64: Default[Complex64] = new Default[Complex64] { def dtype = torch.complex64 } + given complex128: Default[Complex128] = new Default[Complex128] { def dtype = torch.complex128 } +object Default extends LowPriorityDefaults: + given float32: Default[Float32] = new Default[Float32] { def dtype = torch.float32 } + /** DType combinations * */ type FloatNN = Float16 | Float32 | Float64 | BFloat16 @@ -452,4 +484,3 @@ transparent inline def deriveDType[T <: DType]: DType = case _: Float16 => float16 case _: Undefined => undefined case _: NumOptions => numoptions - case _ => float32 diff --git a/core/src/main/scala/torch/nn/modules/Module.scala b/core/src/main/scala/torch/nn/modules/Module.scala index ab0e3482..04ad69c6 100644 --- a/core/src/main/scala/torch/nn/modules/Module.scala +++ b/core/src/main/scala/torch/nn/modules/Module.scala @@ -70,11 +70,6 @@ abstract class Module { def namedModules: SeqMap[String, Module] = namedChildren.flatMap((name, module) => module.namedModules) - def copy(): this.type = - val clone = super.clone().asInstanceOf[Module] - clone._nativeModule = _nativeModule.clone(null) - clone.asInstanceOf[this.type] - def register[M <: Module](child: M)(using name: sourcecode.Name) = // println(s"registering ${name.value}:$child") childModules = childModules.updated(name.value, child) @@ -94,16 +89,7 @@ abstract class Module { def train(on: Boolean = true): Unit = nativeModule.train(on) def to(device: Device): this.type = - // val nativeCopy = nativeModule.clone() nativeModule.to(device.toNative, false) - // copy - // val clone: this.type = copy() - // clone.nativeModule = nativeCopy - this - - def to(dtype: DType, nonBlocking: Boolean = false): this.type = - val nativeCopy = nativeModule.clone(null) - nativeCopy.asModule.to(dtype.toScalarType, false) this def save(outputArchive: OutputArchive) = nativeModule.save(outputArchive) @@ -123,20 +109,11 @@ abstract class Module { doSummarize(0) } -/** Default tensor type for module parameters. - * - * Defaults to float32 but can be overriden by providing a given - */ -trait Default[+D <: DType]: - def dtype: D -object Default: - given f32: Default[Float32] = new Default[Float32] { def dtype = float32 } - trait HasParams[ParamType <: FloatNN | ComplexNN: Default] extends Module: override def parameters(recurse: Boolean): Seq[Tensor[ParamType]] = nativeModule.parameters(recurse).get().toSeq.map(Tensor.apply[ParamType]) override def parameters: Seq[Tensor[ParamType]] = parameters(recurse = true) - transparent inline def paramType = deriveDType[ParamType] + transparent inline def paramType: DType = summon[Default[ParamType]].dtype trait HasWeight[ParamType <: FloatNN | ComplexNN]: def weight: Tensor[ParamType] diff --git a/core/src/main/scala/torch/nn/modules/activation/Softmax.scala b/core/src/main/scala/torch/nn/modules/activation/Softmax.scala index 986fbd2c..e1891076 100644 --- a/core/src/main/scala/torch/nn/modules/activation/Softmax.scala +++ b/core/src/main/scala/torch/nn/modules/activation/Softmax.scala @@ -14,7 +14,10 @@ * limitations under the License. */ -package torch.nn.modules.activation +package torch +package nn +package modules +package activation import org.bytedeco.pytorch import org.bytedeco.pytorch.SoftmaxImpl @@ -28,7 +31,7 @@ import torch.{DType, Tensor} * * When the input Tensor is a sparse tensor then the unspecifed values are treated as ``-inf``. */ -final class Softmax(dim: Int) extends Module: +final class Softmax[D <: DType: Default](dim: Int) extends TensorModule[D]: override val nativeModule: SoftmaxImpl = SoftmaxImpl(dim) - def apply[D <: DType](t: Tensor[D]): Tensor[D] = Tensor(nativeModule.forward(t.native)) + def apply(t: Tensor[D]): Tensor[D] = Tensor(nativeModule.forward(t.native)) diff --git a/core/src/main/scala/torch/nn/modules/batchnorm/BatchNorm2d.scala b/core/src/main/scala/torch/nn/modules/batchnorm/BatchNorm2d.scala index 2a6ea11c..e106056f 100644 --- a/core/src/main/scala/torch/nn/modules/batchnorm/BatchNorm2d.scala +++ b/core/src/main/scala/torch/nn/modules/batchnorm/BatchNorm2d.scala @@ -24,7 +24,7 @@ import org.bytedeco.pytorch import sourcecode.Name import org.bytedeco.pytorch.BatchNorm2dImpl import org.bytedeco.pytorch.BatchNormOptions -import torch.nn.modules.{HasParams, HasWeight, TensorModule} +import torch.nn.modules.{HasParams, HasWeight} // format: off diff --git a/core/src/main/scala/torch/nn/modules/conv/Conv2d.scala b/core/src/main/scala/torch/nn/modules/conv/Conv2d.scala index 32c88c7a..fed72615 100644 --- a/core/src/main/scala/torch/nn/modules/conv/Conv2d.scala +++ b/core/src/main/scala/torch/nn/modules/conv/Conv2d.scala @@ -26,7 +26,7 @@ import sourcecode.Name import torch.Tensor import torch.internal.NativeConverters.toNative import torch.nn.modules.conv.Conv2d.PaddingMode -import torch.nn.modules.{HasParams, TensorModule} +import torch.nn.modules.{HasParams} /** Applies a 2D convolution over an input signal composed of several input planes. * diff --git a/core/src/main/scala/torch/nn/modules/linear/Identity.scala b/core/src/main/scala/torch/nn/modules/linear/Identity.scala index b1d256a4..88920daa 100644 --- a/core/src/main/scala/torch/nn/modules/linear/Identity.scala +++ b/core/src/main/scala/torch/nn/modules/linear/Identity.scala @@ -28,7 +28,7 @@ import torch.{DType, Tensor} * * @group nn_linear */ -final class Identity(args: Any*) extends Module: +final class Identity[D <: DType: Default](args: Any*) extends TensorModule[D]: override val nativeModule: IdentityImpl = IdentityImpl() - def apply[D <: DType](t: Tensor[D]): Tensor[D] = Tensor(nativeModule.forward(t.native)) + def apply(t: Tensor[D]): Tensor[D] = Tensor(nativeModule.forward(t.native)) diff --git a/core/src/main/scala/torch/nn/modules/linear/Linear.scala b/core/src/main/scala/torch/nn/modules/linear/Linear.scala index 65ab73f5..3d39a3c3 100644 --- a/core/src/main/scala/torch/nn/modules/linear/Linear.scala +++ b/core/src/main/scala/torch/nn/modules/linear/Linear.scala @@ -22,7 +22,7 @@ package linear import org.bytedeco.pytorch import org.bytedeco.pytorch.{LinearImpl, LinearOptions} import torch.Tensor -import torch.nn.modules.{HasParams, TensorModule} +import torch.nn.modules.{HasParams} /** Applies a linear transformation to the incoming data: $y = xA^T + b$ * @@ -53,7 +53,7 @@ final class Linear[ParamType <: FloatNN: Default]( bias: Boolean = true // dtype: ParamType = defaultDType[ParamType] ) extends HasParams[ParamType] - with (TensorModule[ParamType]): + with TensorModule[ParamType]: private val options = new LinearOptions(inFeatures, outFeatures) options.bias().put(bias) override private[torch] val nativeModule: LinearImpl = new LinearImpl(options) diff --git a/core/src/main/scala/torch/nn/modules/normalization/GroupNorm.scala b/core/src/main/scala/torch/nn/modules/normalization/GroupNorm.scala index bcd588bf..d58ad75e 100644 --- a/core/src/main/scala/torch/nn/modules/normalization/GroupNorm.scala +++ b/core/src/main/scala/torch/nn/modules/normalization/GroupNorm.scala @@ -21,7 +21,6 @@ package normalization import org.bytedeco.pytorch import org.bytedeco.pytorch.{GroupNormImpl, GroupNormOptions} -import torch.nn.modules.TensorModule import torch.{DType, Tensor} /** Applies Group Normalization over a mini-batch of inputs @@ -36,12 +35,13 @@ import torch.{DType, Tensor} * a boolean value that when set to `true`, this module has learnable per-channel affine * parameters initialized to ones (for weights) and zeros (for biases) */ -final class GroupNorm[ParamType <: DType]( +final class GroupNorm[ParamType <: FloatNN | ComplexNN: Default]( numGroups: Int, numChannels: Int, eps: Double = 1e-05, affine: Boolean = true -) extends TensorModule[ParamType]: +) extends HasWeight[ParamType] + with TensorModule[ParamType]: private val options: GroupNormOptions = GroupNormOptions(numGroups, numChannels) options.eps().put(eps) options.affine().put(affine) diff --git a/core/src/main/scala/torch/nn/modules/pooling/AdaptiveAvgPool2d.scala b/core/src/main/scala/torch/nn/modules/pooling/AdaptiveAvgPool2d.scala index a685cb5e..e3c6e415 100644 --- a/core/src/main/scala/torch/nn/modules/pooling/AdaptiveAvgPool2d.scala +++ b/core/src/main/scala/torch/nn/modules/pooling/AdaptiveAvgPool2d.scala @@ -32,7 +32,7 @@ import org.bytedeco.pytorch.LongOptional * The output is of size H x W, for any input size. The number of output features is equal to the * number of input planes. */ -final class AdaptiveAvgPool2d( +final class AdaptiveAvgPool2d[D <: BFloat16 | Float32 | Float64: Default]( outputSize: Int | Option[Int] | (Option[Int], Option[Int]) | (Int, Int) ) extends Module { @@ -49,7 +49,7 @@ final class AdaptiveAvgPool2d( nativeOutputSize.get(0) ) - def apply[D <: BFloat16 | Float32 | Float64](t: Tensor[D]): Tensor[D] = Tensor( + def apply(t: Tensor[D]): Tensor[D] = Tensor( nativeModule.forward(t.native) ) } diff --git a/core/src/main/scala/torch/nn/modules/pooling/MaxPool2d.scala b/core/src/main/scala/torch/nn/modules/pooling/MaxPool2d.scala index 9f21cd70..10785797 100644 --- a/core/src/main/scala/torch/nn/modules/pooling/MaxPool2d.scala +++ b/core/src/main/scala/torch/nn/modules/pooling/MaxPool2d.scala @@ -23,19 +23,18 @@ import org.bytedeco.javacpp.LongPointer import org.bytedeco.pytorch import org.bytedeco.pytorch.{MaxPool2dImpl, MaxPool2dOptions} import torch.internal.NativeConverters.toNative -import torch.nn.modules.{HasParams, TensorModule} +import torch.nn.modules.{HasParams} import torch.{BFloat16, Float32, Float64, Tensor} /** Applies a 2D max pooling over an input signal composed of several input planes. */ -final class MaxPool2d[ParamType <: BFloat16 | Float32 | Float64: Default]( +final class MaxPool2d[D <: BFloat16 | Float32 | Float64: Default]( kernelSize: Int | (Int, Int), stride: Option[Int | (Int, Int)] = None, padding: Int | (Int, Int) = 0, dilation: Int | (Int, Int) = 1, // returnIndices: Boolean = false, ceilMode: Boolean = false -) extends HasParams[ParamType] - with TensorModule[ParamType]: +) extends TensorModule[D]: private val options: MaxPool2dOptions = MaxPool2dOptions(toNative(kernelSize)) stride.foreach(s => options.stride().put(toNative(s))) @@ -44,10 +43,9 @@ final class MaxPool2d[ParamType <: BFloat16 | Float32 | Float64: Default]( options.ceil_mode().put(ceilMode) override private[torch] val nativeModule: MaxPool2dImpl = MaxPool2dImpl(options) - nativeModule.to(paramType.toScalarType, false) override def toString(): String = s"MaxPool2d(kernelSize=$kernelSize, stride=$stride, padding=$padding, dilation=$dilation, ceilMode=$ceilMode)" - def apply(t: Tensor[ParamType]): Tensor[ParamType] = Tensor(nativeModule.forward(t.native)) + def apply(t: Tensor[D]): Tensor[D] = Tensor(nativeModule.forward(t.native)) // TODO forward_with_indices diff --git a/core/src/main/scala/torch/nn/package.scala b/core/src/main/scala/torch/nn/package.scala index 2d899379..ded65ebb 100644 --- a/core/src/main/scala/torch/nn/package.scala +++ b/core/src/main/scala/torch/nn/package.scala @@ -25,7 +25,6 @@ package torch package object nn { export modules.Module - export modules.Default export modules.activation.Softmax export modules.activation.ReLU diff --git a/docs/modules.md b/docs/modules.md index e16f8603..ca2e2574 100644 --- a/docs/modules.md +++ b/docs/modules.md @@ -13,7 +13,7 @@ import torch.* import torch.nn import torch.nn.functional as F -class LeNet[D <: BFloat16 | Float32: nn.Default] extends nn.Module: +class LeNet[D <: BFloat16 | Float32: Default] extends nn.Module: val conv1 = register(nn.Conv2d(1, 6, 5)) val pool = register(nn.MaxPool2d((2, 2))) val conv2 = register(nn.Conv2d(6, 16, 5)) diff --git a/docs/tutorial/buildmodel.md b/docs/tutorial/buildmodel.md index 0d0eb6fb..5a01188d 100644 --- a/docs/tutorial/buildmodel.md +++ b/docs/tutorial/buildmodel.md @@ -78,7 +78,7 @@ We get the prediction probabilities by passing it through an instance of the ``n ```scala mdoc val X = torch.rand(Seq(1, 28, 28), device=device) val logits = model(X) -val predProbab = nn.Softmax(dim=1)(logits) +val predProbab = nn.Softmax(dim=1).apply(logits) val yPred = predProbab.argmax(1) println(s"Predicted class: $yPred") ``` diff --git a/examples/src/main/scala/ImageClassifier.scala b/examples/src/main/scala/ImageClassifier.scala index f9d2a8b3..a363db1c 100644 --- a/examples/src/main/scala/ImageClassifier.scala +++ b/examples/src/main/scala/ImageClassifier.scala @@ -16,14 +16,13 @@ //> using scala "3.3" //> using repository "sonatype-s01:snapshots" -//> using lib "dev.storch::vision:0.0-ab8d84c-SNAPSHOT" +//> using lib "dev.storch::vision:0.0-795485b-SNAPSHOT" //> using lib "me.tongfei:progressbar:0.9.5" //> using lib "com.github.alexarchambault::case-app:2.1.0-M24" //> using lib "org.scala-lang.modules::scala-parallel-collections:1.0.4" // replace with pytorch-platform-gpu if you have a CUDA capable GPU //> using lib "org.bytedeco:pytorch-platform:2.0.1-1.5.9" // enable for CUDA support -////> using lib "org.bytedeco:cuda-platform:12.1-8.9-1.5.9" ////> using lib "org.bytedeco:cuda-platform-redist:12.1-8.9-1.5.9" import Commands.* diff --git a/examples/src/main/scala/LeNet.scala b/examples/src/main/scala/LeNet.scala index 8204fdaf..bcc7aa7f 100644 --- a/examples/src/main/scala/LeNet.scala +++ b/examples/src/main/scala/LeNet.scala @@ -16,18 +16,16 @@ //> using scala "3.3" //> using repository "sonatype-s01:snapshots" -//> using lib "dev.storch::vision:0.0-ab8d84c-SNAPSHOT" +//> using lib "dev.storch::vision:0.0-795485b-SNAPSHOT" // replace with pytorch-platform-gpu if you have a CUDA capable GPU //> using lib "org.bytedeco:pytorch-platform:2.0.1-1.5.9" // enable for CUDA support -////> using lib "org.bytedeco:cuda-platform:12.1-8.9-1.5.9" ////> using lib "org.bytedeco:cuda-platform-redist:12.1-8.9-1.5.9" import torch.* import torch.nn.functional as F import torch.optim.Adam import org.bytedeco.pytorch.OutputArchive -import torch.nn.modules.Default import torchvision.datasets.MNIST import scala.util.Random import java.nio.file.Paths @@ -35,9 +33,11 @@ import torch.Device.CUDA import scala.util.Using import org.bytedeco.javacpp.PointerScope import torch.Device.CPU +import torch.nn.modules.HasParams + +// Define the model architecture +class LeNet[D <: BFloat16 | Float32: Default] extends HasParams[D] { -// define model architecture -class LeNet[D <: BFloat16 | Float32: Default] extends nn.Module { val conv1 = register(nn.Conv2d(1, 6, 5)) val pool = register(nn.MaxPool2d((2, 2))) val conv2 = register(nn.Conv2d(6, 16, 5)) diff --git a/vision/src/main/scala/torchvision/models/resnet.scala b/vision/src/main/scala/torchvision/models/resnet.scala index a5cd7115..3f4ca09f 100644 --- a/vision/src/main/scala/torchvision/models/resnet.scala +++ b/vision/src/main/scala/torchvision/models/resnet.scala @@ -17,7 +17,18 @@ package torchvision package models -import torch.{BFloat16, ComplexNN, DType, Float32, Float32Tensor, Float64, FloatNN, Tensor, nn} +import torch.{ + BFloat16, + ComplexNN, + DType, + Default, + Float32, + Float32Tensor, + Float64, + FloatNN, + Tensor, + nn +} import torch.nn.init.{Mode, NonLinearity, constant_, kaimingNormal_} import scala.collection.mutable @@ -28,12 +39,13 @@ import sourcecode.Name import torch.nn.modules.activation.ReLU import torch.nn.modules.conv.Conv2d import torch.nn.modules.pooling.{AdaptiveAvgPool2d, MaxPool2d} -import torch.nn.modules.{Default, HasWeight, Module, TensorModule} +import torch.nn.modules.{HasWeight, Module} import torchvision.transforms.* import scala.util.Using import com.sksamuel.scrimage.ImmutableImage import torch.Int32 +import torch.nn.modules.TensorModule /** ResNet architecture implementations *