diff --git a/core/src/main/scala/torch/DType.scala b/core/src/main/scala/torch/DType.scala
index 18b33c6d..06e835b9 100644
--- a/core/src/main/scala/torch/DType.scala
+++ b/core/src/main/scala/torch/DType.scala
@@ -209,6 +209,38 @@ private object Derive:
 val derive: Derive = Derive()
 export Derive.derive
 
+/** Default tensor type.
+ *
+ * Defaults to float32 but can be overridden by providing the DType explicitly, or by overriding
+ * the default in the current scope through an import:
+ *
+ * Example:
+ * ```scala sc
+ * import torch.*
+ *
+ * // Default:
+ * nn.Linear(10, 10) // Linear[Float32]
+ *
+ * // Override explicitly:
+ * nn.Linear[BFloat16](10, 10) // Linear[BFloat16]
+ *
+ * // Override default:
+ * import torch.Default.float64
+ * nn.Linear(10, 10) // Linear[Float64]
+ * ```
+ */
+trait Default[+D <: DType]:
+  def dtype: D
+trait LowPriorityDefaults:
+  given float16: Default[Float16] = new Default[Float16] { def dtype = torch.float16 }
+  given bfloat16: Default[BFloat16] = new Default[BFloat16] { def dtype = torch.bfloat16 }
+  given float64: Default[Float64] = new Default[Float64] { def dtype = torch.float64 }
+  given complex32: Default[Complex32] = new Default[Complex32] { def dtype = torch.complex32 }
+  given complex64: Default[Complex64] = new Default[Complex64] { def dtype = torch.complex64 }
+  given complex128: Default[Complex128] = new Default[Complex128] { def dtype = torch.complex128 }
+object Default extends LowPriorityDefaults:
+  given float32: Default[Float32] = new Default[Float32] { def dtype = torch.float32 }
+
 /** DType combinations * */
 type FloatNN = Float16 | Float32 | Float64 | BFloat16
@@ -452,4 +484,3 @@ transparent inline def deriveDType[T <: DType]: DType =
     case _: Float16 => float16
     case _: Undefined => undefined
     case _: NumOptions => numoptions
-    case _ => float32
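The hunk above is the core of the change: the default dtype becomes a typeclass, with `float32` declared in the companion object and every other instance in a low-priority parent trait, so `float32` wins whenever the type parameter is left open. A self-contained sketch of that resolution order, using hypothetical stand-in types rather than the storch definitions:

```scala
// Sketch of the low-priority-given pattern from the hunk above.
// DT, F32 and F64 are simplified stand-ins for torch's dtype types.
trait DT
case object F32 extends DT
case object F64 extends DT

trait Default[D <: DT]:
  def dtype: DT

trait LowPriorityDefaults:
  given float64: Default[F64.type] = new Default[F64.type] { def dtype = F64 }

object Default extends LowPriorityDefaults:
  // Declared in the subclass, so it is preferred while D is unconstrained.
  given float32: Default[F32.type] = new Default[F32.type] { def dtype = F32 }

// Mirrors how a module picks up its parameter dtype.
def defaultDType[D <: DT](using d: Default[D]): DT = d.dtype

@main def demo(): Unit =
  println(defaultDType)           // F32: companion given wins by specificity
  println(defaultDType[F64.type]) // F64: explicit type argument
  locally {
    import Default.float64
    println(defaultDType)         // F64: imported given is found first
  }
```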
diff --git a/core/src/main/scala/torch/nn/modules/Module.scala b/core/src/main/scala/torch/nn/modules/Module.scala
index ab0e3482..04ad69c6 100644
--- a/core/src/main/scala/torch/nn/modules/Module.scala
+++ b/core/src/main/scala/torch/nn/modules/Module.scala
@@ -70,11 +70,6 @@ abstract class Module {
   def namedModules: SeqMap[String, Module] =
     namedChildren.flatMap((name, module) => module.namedModules)
 
-  def copy(): this.type =
-    val clone = super.clone().asInstanceOf[Module]
-    clone._nativeModule = _nativeModule.clone(null)
-    clone.asInstanceOf[this.type]
-
   def register[M <: Module](child: M)(using name: sourcecode.Name) =
     // println(s"registering ${name.value}:$child")
     childModules = childModules.updated(name.value, child)
@@ -94,16 +89,7 @@ abstract class Module {
   def train(on: Boolean = true): Unit = nativeModule.train(on)
 
   def to(device: Device): this.type =
-    // val nativeCopy = nativeModule.clone()
     nativeModule.to(device.toNative, false)
-    // copy
-    // val clone: this.type = copy()
-    // clone.nativeModule = nativeCopy
-    this
-
-  def to(dtype: DType, nonBlocking: Boolean = false): this.type =
-    val nativeCopy = nativeModule.clone(null)
-    nativeCopy.asModule.to(dtype.toScalarType, false)
     this
 
   def save(outputArchive: OutputArchive) = nativeModule.save(outputArchive)
@@ -123,20 +109,11 @@ abstract class Module {
     doSummarize(0)
 }
 
-/** Default tensor type for module parameters.
- *
- * Defaults to float32 but can be overriden by providing a given
- */
-trait Default[+D <: DType]:
-  def dtype: D
-object Default:
-  given f32: Default[Float32] = new Default[Float32] { def dtype = float32 }
-
 trait HasParams[ParamType <: FloatNN | ComplexNN: Default] extends Module:
   override def parameters(recurse: Boolean): Seq[Tensor[ParamType]] =
     nativeModule.parameters(recurse).get().toSeq.map(Tensor.apply[ParamType])
   override def parameters: Seq[Tensor[ParamType]] = parameters(recurse = true)
-  transparent inline def paramType = deriveDType[ParamType]
+  transparent inline def paramType: DType = summon[Default[ParamType]].dtype
 
 trait HasWeight[ParamType <: FloatNN | ComplexNN]:
   def weight: Tensor[ParamType]
diff --git a/core/src/main/scala/torch/nn/modules/activation/Softmax.scala b/core/src/main/scala/torch/nn/modules/activation/Softmax.scala
index 986fbd2c..e1891076 100644
--- a/core/src/main/scala/torch/nn/modules/activation/Softmax.scala
+++ b/core/src/main/scala/torch/nn/modules/activation/Softmax.scala
@@ -14,7 +14,10 @@
  * limitations under the License.
  */
 
-package torch.nn.modules.activation
+package torch
+package nn
+package modules
+package activation
 
 import org.bytedeco.pytorch
 import org.bytedeco.pytorch.SoftmaxImpl
@@ -28,7 +31,7 @@ import torch.{DType, Tensor}
  *
 * When the input Tensor is a sparse tensor then the unspecifed values are treated as ``-inf``.
 */
-final class Softmax(dim: Int) extends Module:
+final class Softmax[D <: DType: Default](dim: Int) extends TensorModule[D]:
   override val nativeModule: SoftmaxImpl = SoftmaxImpl(dim)
 
-  def apply[D <: DType](t: Tensor[D]): Tensor[D] = Tensor(nativeModule.forward(t.native))
+  def apply(t: Tensor[D]): Tensor[D] = Tensor(nativeModule.forward(t.native))
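With `Default` available, `paramType` can simply summon the instance instead of pattern-matching via `deriveDType`, and activation modules such as `Softmax` now fix their dtype at construction rather than per `apply` call. A usage sketch of the new signatures (illustrative, not part of the diff; assumes the usual `torch.rand` defaults):

```scala
import torch.*

val softmax = nn.Softmax(dim = 1)            // Softmax[Float32] via Default
val softmax64 = nn.Softmax[Float64](dim = 1) // explicit dtype override

// apply now accepts and returns only tensors of the module's own dtype:
val probs: Tensor[Float32] = softmax(torch.rand(Seq(2, 3)))
// softmax64(torch.rand(Seq(2, 3))) // would not compile: Float32 is not Float64
```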
 *
diff --git a/core/src/main/scala/torch/nn/modules/batchnorm/BatchNorm2d.scala b/core/src/main/scala/torch/nn/modules/batchnorm/BatchNorm2d.scala
index 2a6ea11c..e106056f 100644
--- a/core/src/main/scala/torch/nn/modules/batchnorm/BatchNorm2d.scala
+++ b/core/src/main/scala/torch/nn/modules/batchnorm/BatchNorm2d.scala
@@ -24,7 +24,7 @@
 import org.bytedeco.pytorch
 import sourcecode.Name
 import org.bytedeco.pytorch.BatchNorm2dImpl
 import org.bytedeco.pytorch.BatchNormOptions
-import torch.nn.modules.{HasParams, HasWeight, TensorModule}
+import torch.nn.modules.{HasParams, HasWeight}
 
 // format: off
diff --git a/core/src/main/scala/torch/nn/modules/conv/Conv2d.scala b/core/src/main/scala/torch/nn/modules/conv/Conv2d.scala
index 32c88c7a..fed72615 100644
--- a/core/src/main/scala/torch/nn/modules/conv/Conv2d.scala
+++ b/core/src/main/scala/torch/nn/modules/conv/Conv2d.scala
@@ -26,7 +26,7 @@ import sourcecode.Name
 import torch.Tensor
 import torch.internal.NativeConverters.toNative
 import torch.nn.modules.conv.Conv2d.PaddingMode
-import torch.nn.modules.{HasParams, TensorModule}
+import torch.nn.modules.{HasParams}
 
 /** Applies a 2D convolution over an input signal composed of several input planes.
  *
diff --git a/core/src/main/scala/torch/nn/modules/linear/Identity.scala b/core/src/main/scala/torch/nn/modules/linear/Identity.scala
index b1d256a4..88920daa 100644
--- a/core/src/main/scala/torch/nn/modules/linear/Identity.scala
+++ b/core/src/main/scala/torch/nn/modules/linear/Identity.scala
@@ -28,7 +28,7 @@ import torch.{DType, Tensor}
  *
  * @group nn_linear
  */
-final class Identity(args: Any*) extends Module:
+final class Identity[D <: DType: Default](args: Any*) extends TensorModule[D]:
   override val nativeModule: IdentityImpl = IdentityImpl()
 
-  def apply[D <: DType](t: Tensor[D]): Tensor[D] = Tensor(nativeModule.forward(t.native))
+  def apply(t: Tensor[D]): Tensor[D] = Tensor(nativeModule.forward(t.native))
diff --git a/core/src/main/scala/torch/nn/modules/linear/Linear.scala b/core/src/main/scala/torch/nn/modules/linear/Linear.scala
index 65ab73f5..3d39a3c3 100644
--- a/core/src/main/scala/torch/nn/modules/linear/Linear.scala
+++ b/core/src/main/scala/torch/nn/modules/linear/Linear.scala
@@ -22,7 +22,7 @@ package linear
 import org.bytedeco.pytorch
 import org.bytedeco.pytorch.{LinearImpl, LinearOptions}
 import torch.Tensor
-import torch.nn.modules.{HasParams, TensorModule}
+import torch.nn.modules.{HasParams}
 
 /** Applies a linear transformation to the incoming data: $y = xA^T + b$
  *
@@ -53,7 +53,7 @@ final class Linear[ParamType <: FloatNN: Default](
     bias: Boolean = true
     // dtype: ParamType = defaultDType[ParamType]
 ) extends HasParams[ParamType]
-    with (TensorModule[ParamType]):
+    with TensorModule[ParamType]:
   private val options = new LinearOptions(inFeatures, outFeatures)
   options.bias().put(bias)
   override private[torch] val nativeModule: LinearImpl = new LinearImpl(options)
diff --git a/core/src/main/scala/torch/nn/modules/normalization/GroupNorm.scala b/core/src/main/scala/torch/nn/modules/normalization/GroupNorm.scala
index bcd588bf..d58ad75e 100644
--- a/core/src/main/scala/torch/nn/modules/normalization/GroupNorm.scala
+++ b/core/src/main/scala/torch/nn/modules/normalization/GroupNorm.scala
@@ -21,7 +21,6 @@ package normalization
 
 import org.bytedeco.pytorch
 import org.bytedeco.pytorch.{GroupNormImpl, GroupNormOptions}
-import torch.nn.modules.TensorModule
 import torch.{DType, Tensor}
 
 /** Applies Group Normalization over a mini-batch of inputs
@@ -36,12 +35,13 @@ import torch.{DType, Tensor}
  *   a boolean value that when set to `true`, this module has learnable per-channel affine
  *   parameters initialized to ones (for weights) and zeros (for biases)
  */
-final class GroupNorm[ParamType <: DType](
+final class GroupNorm[ParamType <: FloatNN | ComplexNN: Default](
     numGroups: Int,
     numChannels: Int,
     eps: Double = 1e-05,
     affine: Boolean = true
-) extends TensorModule[ParamType]:
+) extends HasWeight[ParamType]
+    with TensorModule[ParamType]:
   private val options: GroupNormOptions = GroupNormOptions(numGroups, numChannels)
   options.eps().put(eps)
   options.affine().put(affine)
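Tightening `GroupNorm` from `ParamType <: DType` to `FloatNN | ComplexNN: Default` turns an unsupported dtype from a runtime failure in native code into a compile error, and `HasWeight` records that the module carries a learnable weight. A usage sketch (assuming `GroupNorm` is exported from the `nn` package like the other modules shown):

```scala
import torch.*

val norm = nn.GroupNorm(numGroups = 3, numChannels = 6) // GroupNorm[Float32]
val out: Tensor[Float32] = norm(torch.rand(Seq(1, 6, 4, 4)))

// nn.GroupNorm[Int32](numGroups = 3, numChannels = 6)
// would not compile: Int32 is not within FloatNN | ComplexNN
```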
diff --git a/core/src/main/scala/torch/nn/modules/pooling/AdaptiveAvgPool2d.scala b/core/src/main/scala/torch/nn/modules/pooling/AdaptiveAvgPool2d.scala
index a685cb5e..e3c6e415 100644
--- a/core/src/main/scala/torch/nn/modules/pooling/AdaptiveAvgPool2d.scala
+++ b/core/src/main/scala/torch/nn/modules/pooling/AdaptiveAvgPool2d.scala
@@ -32,7 +32,7 @@ import org.bytedeco.pytorch.LongOptional
  *
  * The output is of size H x W, for any input size. The number of output features is equal to the
  * number of input planes.
  */
-final class AdaptiveAvgPool2d(
+final class AdaptiveAvgPool2d[D <: BFloat16 | Float32 | Float64: Default](
     outputSize: Int | Option[Int] | (Option[Int], Option[Int]) | (Int, Int)
 ) extends Module {
@@ -49,7 +49,7 @@ final class AdaptiveAvgPool2d(
       nativeOutputSize.get(0)
     )
 
-  def apply[D <: BFloat16 | Float32 | Float64](t: Tensor[D]): Tensor[D] = Tensor(
+  def apply(t: Tensor[D]): Tensor[D] = Tensor(
     nativeModule.forward(t.native)
   )
 }
diff --git a/core/src/main/scala/torch/nn/modules/pooling/MaxPool2d.scala b/core/src/main/scala/torch/nn/modules/pooling/MaxPool2d.scala
index 9f21cd70..10785797 100644
--- a/core/src/main/scala/torch/nn/modules/pooling/MaxPool2d.scala
+++ b/core/src/main/scala/torch/nn/modules/pooling/MaxPool2d.scala
@@ -23,19 +23,18 @@ import org.bytedeco.javacpp.LongPointer
 import org.bytedeco.pytorch
 import org.bytedeco.pytorch.{MaxPool2dImpl, MaxPool2dOptions}
 import torch.internal.NativeConverters.toNative
-import torch.nn.modules.{HasParams, TensorModule}
+import torch.nn.modules.{HasParams}
 import torch.{BFloat16, Float32, Float64, Tensor}
 
 /** Applies a 2D max pooling over an input signal composed of several input planes. */
-final class MaxPool2d[ParamType <: BFloat16 | Float32 | Float64: Default](
+final class MaxPool2d[D <: BFloat16 | Float32 | Float64: Default](
     kernelSize: Int | (Int, Int),
     stride: Option[Int | (Int, Int)] = None,
     padding: Int | (Int, Int) = 0,
     dilation: Int | (Int, Int) = 1,
     // returnIndices: Boolean = false,
     ceilMode: Boolean = false
-) extends HasParams[ParamType]
-    with TensorModule[ParamType]:
+) extends TensorModule[D]:
   private val options: MaxPool2dOptions = MaxPool2dOptions(toNative(kernelSize))
 
   stride.foreach(s => options.stride().put(toNative(s)))
@@ -44,10 +43,9 @@ final class MaxPool2d[ParamType <: BFloat16 | Float32 | Float64: Default](
   options.ceil_mode().put(ceilMode)
   override private[torch] val nativeModule: MaxPool2dImpl = MaxPool2dImpl(options)
-  nativeModule.to(paramType.toScalarType, false)
 
   override def toString(): String =
     s"MaxPool2d(kernelSize=$kernelSize, stride=$stride, padding=$padding, dilation=$dilation, ceilMode=$ceilMode)"
 
-  def apply(t: Tensor[ParamType]): Tensor[ParamType] = Tensor(nativeModule.forward(t.native))
+  def apply(t: Tensor[D]): Tensor[D] = Tensor(nativeModule.forward(t.native))
 
 // TODO forward_with_indices
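Note the asymmetry in the pooling modules: `MaxPool2d` keeps a dtype parameter so that `apply` stays typed, but it no longer extends `HasParams`, and the dropped `nativeModule.to(paramType.toScalarType, false)` call had nothing to convert because pooling has no learnable parameters. A usage sketch:

```scala
import torch.*

// The type parameter only pins down apply's signature; no weights exist.
val pool = nn.MaxPool2d[Float32]((2, 2))
val pooled: Tensor[Float32] = pool(torch.rand(Seq(1, 1, 28, 28)))
```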
diff --git a/core/src/main/scala/torch/nn/package.scala b/core/src/main/scala/torch/nn/package.scala
index 2d899379..ded65ebb 100644
--- a/core/src/main/scala/torch/nn/package.scala
+++ b/core/src/main/scala/torch/nn/package.scala
@@ -25,7 +25,6 @@ package torch
 package object nn {
 
   export modules.Module
-  export modules.Default
   export modules.activation.Softmax
   export modules.activation.ReLU
diff --git a/docs/modules.md b/docs/modules.md
index e16f8603..ca2e2574 100644
--- a/docs/modules.md
+++ b/docs/modules.md
@@ -13,7 +13,7 @@ import torch.*
 import torch.nn
 import torch.nn.functional as F
 
-class LeNet[D <: BFloat16 | Float32: nn.Default] extends nn.Module:
+class LeNet[D <: BFloat16 | Float32: Default] extends nn.Module:
   val conv1 = register(nn.Conv2d(1, 6, 5))
   val pool = register(nn.MaxPool2d((2, 2)))
   val conv2 = register(nn.Conv2d(6, 16, 5))
diff --git a/docs/tutorial/buildmodel.md b/docs/tutorial/buildmodel.md
index 0d0eb6fb..5a01188d 100644
--- a/docs/tutorial/buildmodel.md
+++ b/docs/tutorial/buildmodel.md
@@ -78,7 +78,7 @@ We get the prediction probabilities by passing it through an instance of the ``n
 ```scala mdoc
 val X = torch.rand(Seq(1, 28, 28), device=device)
 val logits = model(X)
-val predProbab = nn.Softmax(dim=1)(logits)
+val predProbab = nn.Softmax(dim=1).apply(logits)
 val yPred = predProbab.argmax(1)
 println(s"Predicted class: $yPred")
 ```
diff --git a/examples/src/main/scala/ImageClassifier.scala b/examples/src/main/scala/ImageClassifier.scala
index f9d2a8b3..a363db1c 100644
--- a/examples/src/main/scala/ImageClassifier.scala
+++ b/examples/src/main/scala/ImageClassifier.scala
@@ -16,14 +16,13 @@
 
 //> using scala "3.3"
 //> using repository "sonatype-s01:snapshots"
-//> using lib "dev.storch::vision:0.0-ab8d84c-SNAPSHOT"
+//> using lib "dev.storch::vision:0.0-795485b-SNAPSHOT"
 //> using lib "me.tongfei:progressbar:0.9.5"
 //> using lib "com.github.alexarchambault::case-app:2.1.0-M24"
 //> using lib "org.scala-lang.modules::scala-parallel-collections:1.0.4"
 // replace with pytorch-platform-gpu if you have a CUDA capable GPU
 //> using lib "org.bytedeco:pytorch-platform:2.0.1-1.5.9"
 // enable for CUDA support
-////> using lib "org.bytedeco:cuda-platform:12.1-8.9-1.5.9"
 ////> using lib "org.bytedeco:cuda-platform-redist:12.1-8.9-1.5.9"
 
 import Commands.*
diff --git a/examples/src/main/scala/LeNet.scala b/examples/src/main/scala/LeNet.scala
index 8204fdaf..bcc7aa7f 100644
--- a/examples/src/main/scala/LeNet.scala
+++ b/examples/src/main/scala/LeNet.scala
@@ -16,18 +16,16 @@
 
 //> using scala "3.3"
 //> using repository "sonatype-s01:snapshots"
-//> using lib "dev.storch::vision:0.0-ab8d84c-SNAPSHOT"
+//> using lib "dev.storch::vision:0.0-795485b-SNAPSHOT"
 // replace with pytorch-platform-gpu if you have a CUDA capable GPU
 //> using lib "org.bytedeco:pytorch-platform:2.0.1-1.5.9"
 // enable for CUDA support
-////> using lib "org.bytedeco:cuda-platform:12.1-8.9-1.5.9"
 ////> using lib "org.bytedeco:cuda-platform-redist:12.1-8.9-1.5.9"
 
 import torch.*
 import torch.nn.functional as F
 import torch.optim.Adam
 import org.bytedeco.pytorch.OutputArchive
-import torch.nn.modules.Default
 import torchvision.datasets.MNIST
 import scala.util.Random
 import java.nio.file.Paths
@@ -35,9 +33,11 @@ import torch.Device.CUDA
 import scala.util.Using
 import org.bytedeco.javacpp.PointerScope
 import torch.Device.CPU
+import torch.nn.modules.HasParams
+
+// Define the model architecture
+class LeNet[D <: BFloat16 | Float32: Default] extends HasParams[D] {
 
-// define model architecture
-class LeNet[D <: BFloat16 | Float32: Default] extends nn.Module {
   val conv1 = register(nn.Conv2d(1, 6, 5))
   val pool = register(nn.MaxPool2d((2, 2)))
   val conv2 = register(nn.Conv2d(6, 16, 5))
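In the example model, extending `HasParams[D]` instead of `nn.Module` gives dtype-typed access to the parameters. A sketch of what that buys at a call site (hypothetical usage, not part of the example file):

```scala
import torch.*

val model = LeNet[Float32]()
// parameters are typed by the model's dtype, so no casting is needed:
val params: Seq[Tensor[Float32]] = model.parameters
println(params.length)
```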
diff --git a/vision/src/main/scala/torchvision/models/resnet.scala b/vision/src/main/scala/torchvision/models/resnet.scala
index a5cd7115..3f4ca09f 100644
--- a/vision/src/main/scala/torchvision/models/resnet.scala
+++ b/vision/src/main/scala/torchvision/models/resnet.scala
@@ -17,7 +17,18 @@
 package torchvision
 package models
 
-import torch.{BFloat16, ComplexNN, DType, Float32, Float32Tensor, Float64, FloatNN, Tensor, nn}
+import torch.{
+  BFloat16,
+  ComplexNN,
+  DType,
+  Default,
+  Float32,
+  Float32Tensor,
+  Float64,
+  FloatNN,
+  Tensor,
+  nn
+}
 import torch.nn.init.{Mode, NonLinearity, constant_, kaimingNormal_}
 
 import scala.collection.mutable
@@ -28,12 +39,13 @@ import sourcecode.Name
 import torch.nn.modules.activation.ReLU
 import torch.nn.modules.conv.Conv2d
 import torch.nn.modules.pooling.{AdaptiveAvgPool2d, MaxPool2d}
-import torch.nn.modules.{Default, HasWeight, Module, TensorModule}
+import torch.nn.modules.{HasWeight, Module}
 import torchvision.transforms.*
 
 import scala.util.Using
 import com.sksamuel.scrimage.ImmutableImage
 import torch.Int32
+import torch.nn.modules.TensorModule
 
 /** ResNet architecture implementations
  *