From 5f18d5a471886d6684142dba7874d7914ab880eb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=B6ren=20Brunk?= Date: Fri, 28 Jul 2023 21:28:10 +0200 Subject: [PATCH 1/2] Simplify module registration after improvements in new presets --- core/src/main/scala/torch/nn/modules/Module.scala | 9 ++------- .../main/scala/torch/nn/modules/activation/ReLU.scala | 5 ----- .../main/scala/torch/nn/modules/activation/Softmax.scala | 5 ----- .../scala/torch/nn/modules/batchnorm/BatchNorm2d.scala | 7 +------ core/src/main/scala/torch/nn/modules/conv/Conv2d.scala | 7 +------ .../main/scala/torch/nn/modules/flatten/Flatten.scala | 5 ----- .../main/scala/torch/nn/modules/linear/Identity.scala | 5 ----- core/src/main/scala/torch/nn/modules/linear/Linear.scala | 7 +------ .../scala/torch/nn/modules/normalization/GroupNorm.scala | 5 ----- .../torch/nn/modules/pooling/AdaptiveAvgPool2d.scala | 7 +------ .../main/scala/torch/nn/modules/pooling/MaxPool2d.scala | 7 +------ 11 files changed, 7 insertions(+), 62 deletions(-) diff --git a/core/src/main/scala/torch/nn/modules/Module.scala b/core/src/main/scala/torch/nn/modules/Module.scala index 4db7b1e3..ab0e3482 100644 --- a/core/src/main/scala/torch/nn/modules/Module.scala +++ b/core/src/main/scala/torch/nn/modules/Module.scala @@ -75,15 +75,10 @@ abstract class Module { clone._nativeModule = _nativeModule.clone(null) clone.asInstanceOf[this.type] - protected[torch] def registerWithParent[T <: pytorch.Module](parent: T)(using - name: sourcecode.Name - ): Unit = - parent.register_module(name.value, nativeModule) - def register[M <: Module](child: M)(using name: sourcecode.Name) = // println(s"registering ${name.value}:$child") childModules = childModules.updated(name.value, child) - child.registerWithParent(this.nativeModule) + nativeModule.register_module(name.value, child.nativeModule) child def register[D <: DType](t: Tensor[D], requiresGrad: Boolean = true)(using @@ -100,7 +95,7 @@ abstract class Module { def to(device: Device): this.type = // val nativeCopy = nativeModule.clone() - nativeModule.asModule.to(device.toNative, false) + nativeModule.to(device.toNative, false) // copy // val clone: this.type = copy() // clone.nativeModule = nativeCopy diff --git a/core/src/main/scala/torch/nn/modules/activation/ReLU.scala b/core/src/main/scala/torch/nn/modules/activation/ReLU.scala index 1e299606..690d633e 100644 --- a/core/src/main/scala/torch/nn/modules/activation/ReLU.scala +++ b/core/src/main/scala/torch/nn/modules/activation/ReLU.scala @@ -34,11 +34,6 @@ final class ReLU[D <: DType: Default](inplace: Boolean = false) extends TensorMo override protected[torch] val nativeModule: ReLUImpl = ReLUImpl() - override def registerWithParent[M <: pytorch.Module](parent: M)(using - name: sourcecode.Name - ): Unit = - parent.register_module(name.value, nativeModule) - def apply(t: Tensor[D]): Tensor[D] = Tensor(nativeModule.forward(t.native)) override def toString = getClass().getSimpleName() diff --git a/core/src/main/scala/torch/nn/modules/activation/Softmax.scala b/core/src/main/scala/torch/nn/modules/activation/Softmax.scala index 1da2211b..986fbd2c 100644 --- a/core/src/main/scala/torch/nn/modules/activation/Softmax.scala +++ b/core/src/main/scala/torch/nn/modules/activation/Softmax.scala @@ -31,9 +31,4 @@ import torch.{DType, Tensor} final class Softmax(dim: Int) extends Module: override val nativeModule: SoftmaxImpl = SoftmaxImpl(dim) - override def registerWithParent[M <: pytorch.Module](parent: M)(using - name: sourcecode.Name - ): Unit = - 
parent.register_module(name.value, nativeModule) - def apply[D <: DType](t: Tensor[D]): Tensor[D] = Tensor(nativeModule.forward(t.native)) diff --git a/core/src/main/scala/torch/nn/modules/batchnorm/BatchNorm2d.scala b/core/src/main/scala/torch/nn/modules/batchnorm/BatchNorm2d.scala index 9c25c128..2a6ea11c 100644 --- a/core/src/main/scala/torch/nn/modules/batchnorm/BatchNorm2d.scala +++ b/core/src/main/scala/torch/nn/modules/batchnorm/BatchNorm2d.scala @@ -119,12 +119,7 @@ final class BatchNorm2d[ParamType <: FloatNN | ComplexNN: Default]( options.track_running_stats().put(trackRunningStats) override private[torch] val nativeModule: BatchNorm2dImpl = BatchNorm2dImpl(options) - nativeModule.asModule.to(paramType.toScalarType, false) - - override def registerWithParent[M <: pytorch.Module](parent: M)(using - name: sourcecode.Name - ): Unit = - parent.register_module(name.value, nativeModule) + nativeModule.to(paramType.toScalarType, false) // TODO weight, bias etc. are undefined if affine = false. We need to take that into account val weight: Tensor[ParamType] = Tensor[ParamType](nativeModule.weight) diff --git a/core/src/main/scala/torch/nn/modules/conv/Conv2d.scala b/core/src/main/scala/torch/nn/modules/conv/Conv2d.scala index f419c274..32c88c7a 100644 --- a/core/src/main/scala/torch/nn/modules/conv/Conv2d.scala +++ b/core/src/main/scala/torch/nn/modules/conv/Conv2d.scala @@ -59,12 +59,7 @@ final class Conv2d[ParamType <: FloatNN | ComplexNN: Default]( options.padding_mode().put(paddingModeNative) override private[torch] val nativeModule: Conv2dImpl = Conv2dImpl(options) - nativeModule.asModule.to(paramType.toScalarType, false) - - override def registerWithParent[M <: pytorch.Module](parent: M)(using - name: sourcecode.Name - ): Unit = - parent.register_module(name.value, nativeModule) + nativeModule.to(paramType.toScalarType, false) def apply(t: Tensor[ParamType]): Tensor[ParamType] = Tensor(nativeModule.forward(t.native)) diff --git a/core/src/main/scala/torch/nn/modules/flatten/Flatten.scala b/core/src/main/scala/torch/nn/modules/flatten/Flatten.scala index 6c8b5844..34fc7b9c 100644 --- a/core/src/main/scala/torch/nn/modules/flatten/Flatten.scala +++ b/core/src/main/scala/torch/nn/modules/flatten/Flatten.scala @@ -59,11 +59,6 @@ final class Flatten[D <: DType: Default](startDim: Int = 1, endDim: Int = -1) override val nativeModule: FlattenImpl = FlattenImpl(options) - override def registerWithParent[T <: pytorch.Module](parent: T)(using - name: sourcecode.Name - ): Unit = - parent.register_module(name.value, nativeModule) - def apply(t: Tensor[D]): Tensor[D] = Tensor(nativeModule.forward(t.native)) override def toString = getClass().getSimpleName() diff --git a/core/src/main/scala/torch/nn/modules/linear/Identity.scala b/core/src/main/scala/torch/nn/modules/linear/Identity.scala index 2beef814..b1d256a4 100644 --- a/core/src/main/scala/torch/nn/modules/linear/Identity.scala +++ b/core/src/main/scala/torch/nn/modules/linear/Identity.scala @@ -31,9 +31,4 @@ import torch.{DType, Tensor} final class Identity(args: Any*) extends Module: override val nativeModule: IdentityImpl = IdentityImpl() - override def registerWithParent[M <: pytorch.Module](parent: M)(using - name: sourcecode.Name - ): Unit = - parent.register_module(name.value, nativeModule) - def apply[D <: DType](t: Tensor[D]): Tensor[D] = Tensor(nativeModule.forward(t.native)) diff --git a/core/src/main/scala/torch/nn/modules/linear/Linear.scala b/core/src/main/scala/torch/nn/modules/linear/Linear.scala index 26ff8cc5..65ab73f5 
100644 --- a/core/src/main/scala/torch/nn/modules/linear/Linear.scala +++ b/core/src/main/scala/torch/nn/modules/linear/Linear.scala @@ -57,12 +57,7 @@ final class Linear[ParamType <: FloatNN: Default]( private val options = new LinearOptions(inFeatures, outFeatures) options.bias().put(bias) override private[torch] val nativeModule: LinearImpl = new LinearImpl(options) - nativeModule.asModule.to(paramType.toScalarType, false) - - override def registerWithParent[T <: pytorch.Module](parent: T)(using - name: sourcecode.Name - ): Unit = - parent.register_module(name.value, nativeModule) + nativeModule.to(paramType.toScalarType, false) def apply(input: Tensor[ParamType]): Tensor[ParamType] = Tensor( nativeModule.forward(input.native) diff --git a/core/src/main/scala/torch/nn/modules/normalization/GroupNorm.scala b/core/src/main/scala/torch/nn/modules/normalization/GroupNorm.scala index fd1bc268..bcd588bf 100644 --- a/core/src/main/scala/torch/nn/modules/normalization/GroupNorm.scala +++ b/core/src/main/scala/torch/nn/modules/normalization/GroupNorm.scala @@ -48,11 +48,6 @@ final class GroupNorm[ParamType <: DType]( override private[torch] val nativeModule: GroupNormImpl = GroupNormImpl(options) - override def registerWithParent[M <: pytorch.Module](parent: M)(using - name: sourcecode.Name - ): Unit = - parent.register_module(name.value, nativeModule) - val weight: Tensor[ParamType] = Tensor[ParamType](nativeModule.weight) val bias: Tensor[ParamType] = Tensor[ParamType](nativeModule.bias) diff --git a/core/src/main/scala/torch/nn/modules/pooling/AdaptiveAvgPool2d.scala b/core/src/main/scala/torch/nn/modules/pooling/AdaptiveAvgPool2d.scala index 04bf4ca8..a685cb5e 100644 --- a/core/src/main/scala/torch/nn/modules/pooling/AdaptiveAvgPool2d.scala +++ b/core/src/main/scala/torch/nn/modules/pooling/AdaptiveAvgPool2d.scala @@ -45,15 +45,10 @@ final class AdaptiveAvgPool2d( case x: Option[Int] => new LongOptionalVector(x.toOptional, x.toOptional) - override private[torch] val nativeModule: AdaptiveAvgPool2dImpl = AdaptiveAvgPool2dImpl( + override protected[torch] val nativeModule: AdaptiveAvgPool2dImpl = AdaptiveAvgPool2dImpl( nativeOutputSize.get(0) ) - override def registerWithParent[T <: pytorch.Module](parent: T)(using - name: sourcecode.Name - ): Unit = - parent.register_module(name.value, nativeModule) - def apply[D <: BFloat16 | Float32 | Float64](t: Tensor[D]): Tensor[D] = Tensor( nativeModule.forward(t.native) ) diff --git a/core/src/main/scala/torch/nn/modules/pooling/MaxPool2d.scala b/core/src/main/scala/torch/nn/modules/pooling/MaxPool2d.scala index abed0ad7..9f21cd70 100644 --- a/core/src/main/scala/torch/nn/modules/pooling/MaxPool2d.scala +++ b/core/src/main/scala/torch/nn/modules/pooling/MaxPool2d.scala @@ -44,12 +44,7 @@ final class MaxPool2d[ParamType <: BFloat16 | Float32 | Float64: Default]( options.ceil_mode().put(ceilMode) override private[torch] val nativeModule: MaxPool2dImpl = MaxPool2dImpl(options) - nativeModule.asModule.to(paramType.toScalarType, false) - - override def registerWithParent[M <: pytorch.Module](parent: M)(using - name: sourcecode.Name - ): Unit = - parent.register_module(name.value, nativeModule) + nativeModule.to(paramType.toScalarType, false) override def toString(): String = s"MaxPool2d(kernelSize=$kernelSize, stride=$stride, padding=$padding, dilation=$dilation, ceilMode=$ceilMode)" From b2b5a2f0796332fe74939bfe1b596b26472cd58e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=B6ren=20Brunk?= Date: Sun, 30 Jul 2023 01:23:56 +0200 Subject: [PATCH 2/2] Fix and 
 improve initializing modules with other parameter types

Remove the broken copy and to(dtype) methods from Module. Module type
conversion will require rethinking the module design, so for now the
parameter type is fixed at creation.
---
 core/src/main/scala/torch/DType.scala         | 33 ++++++++++++++++++-
 .../main/scala/torch/nn/modules/Module.scala  | 25 +-------------
 .../torch/nn/modules/activation/Softmax.scala |  9 +++--
 .../nn/modules/batchnorm/BatchNorm2d.scala    |  2 +-
 .../scala/torch/nn/modules/conv/Conv2d.scala  |  2 +-
 .../torch/nn/modules/linear/Identity.scala    |  4 +--
 .../torch/nn/modules/linear/Linear.scala      |  4 +--
 .../nn/modules/normalization/GroupNorm.scala  |  6 ++--
 .../modules/pooling/AdaptiveAvgPool2d.scala   |  4 +--
 .../torch/nn/modules/pooling/MaxPool2d.scala  | 10 +++---
 core/src/main/scala/torch/nn/package.scala    |  1 -
 docs/modules.md                               |  2 +-
 docs/tutorial/buildmodel.md                   |  2 +-
 examples/src/main/scala/ImageClassifier.scala |  3 +-
 examples/src/main/scala/LeNet.scala           | 10 +++---
 .../scala/torchvision/models/resnet.scala     | 16 +++++++--
 16 files changed, 76 insertions(+), 57 deletions(-)

diff --git a/core/src/main/scala/torch/DType.scala b/core/src/main/scala/torch/DType.scala
index 18b33c6d..d05483ba 100644
--- a/core/src/main/scala/torch/DType.scala
+++ b/core/src/main/scala/torch/DType.scala
@@ -209,6 +209,38 @@ private object Derive:
   val derive: Derive = Derive()
 
 export Derive.derive
 
+/** Default tensor type.
+  *
+  * Defaults to float32 but can be overridden by providing the DType explicitly, or by
+  * overriding the default in the current scope through an import:
+  *
+  * Example:
+  * ```scala sc
+  * import torch.*
+  *
+  * // Default:
+  * nn.Linear(10, 10) // Linear[Float32]
+  *
+  * // Override explicitly:
+  * nn.Linear[BFloat16](10, 10) // Linear[BFloat16]
+  *
+  * // Override default:
+  * import Default.float64
+  * nn.Linear(10, 10) // Linear[Float64]
+  * ```
+  */
+trait Default[+D <: DType]:
+  def dtype: D
+trait LowPriorityDefaults:
+  given float16: Default[Float16] = new Default[Float16] { def dtype = torch.float16 }
+  given bfloat16: Default[BFloat16] = new Default[BFloat16] { def dtype = torch.bfloat16 }
+  given float64: Default[Float64] = new Default[Float64] { def dtype = torch.float64 }
+  given complex32: Default[Complex32] = new Default[Complex32] { def dtype = torch.complex32 }
+  given complex64: Default[Complex64] = new Default[Complex64] { def dtype = torch.complex64 }
+  given complex128: Default[Complex128] = new Default[Complex128] { def dtype = torch.complex128 }
+object Default extends LowPriorityDefaults:
+  given float32: Default[Float32] = new Default[Float32] { def dtype = torch.float32 }
+
 /** DType combinations * */
 type FloatNN = Float16 | Float32 | Float64 | BFloat16
 
@@ -452,4 +484,3 @@ transparent inline def deriveDType[T <: DType]: DType =
     case _: Float16 => float16
     case _: Undefined => undefined
     case _: NumOptions => numoptions
-    case _ => float32
diff --git a/core/src/main/scala/torch/nn/modules/Module.scala b/core/src/main/scala/torch/nn/modules/Module.scala
index ab0e3482..04ad69c6 100644
--- a/core/src/main/scala/torch/nn/modules/Module.scala
+++ b/core/src/main/scala/torch/nn/modules/Module.scala
@@ -70,11 +70,6 @@ abstract class Module {
   def namedModules: SeqMap[String, Module] =
     namedChildren.flatMap((name, module) => module.namedModules)
 
-  def copy(): this.type =
-    val clone = super.clone().asInstanceOf[Module]
-    clone._nativeModule = _nativeModule.clone(null)
-    clone.asInstanceOf[this.type]
-
   def register[M <: Module](child: M)(using name: sourcecode.Name) =
println(s"registering ${name.value}:$child") childModules = childModules.updated(name.value, child) @@ -94,16 +89,7 @@ abstract class Module { def train(on: Boolean = true): Unit = nativeModule.train(on) def to(device: Device): this.type = - // val nativeCopy = nativeModule.clone() nativeModule.to(device.toNative, false) - // copy - // val clone: this.type = copy() - // clone.nativeModule = nativeCopy - this - - def to(dtype: DType, nonBlocking: Boolean = false): this.type = - val nativeCopy = nativeModule.clone(null) - nativeCopy.asModule.to(dtype.toScalarType, false) this def save(outputArchive: OutputArchive) = nativeModule.save(outputArchive) @@ -123,20 +109,11 @@ abstract class Module { doSummarize(0) } -/** Default tensor type for module parameters. - * - * Defaults to float32 but can be overriden by providing a given - */ -trait Default[+D <: DType]: - def dtype: D -object Default: - given f32: Default[Float32] = new Default[Float32] { def dtype = float32 } - trait HasParams[ParamType <: FloatNN | ComplexNN: Default] extends Module: override def parameters(recurse: Boolean): Seq[Tensor[ParamType]] = nativeModule.parameters(recurse).get().toSeq.map(Tensor.apply[ParamType]) override def parameters: Seq[Tensor[ParamType]] = parameters(recurse = true) - transparent inline def paramType = deriveDType[ParamType] + transparent inline def paramType: DType = summon[Default[ParamType]].dtype trait HasWeight[ParamType <: FloatNN | ComplexNN]: def weight: Tensor[ParamType] diff --git a/core/src/main/scala/torch/nn/modules/activation/Softmax.scala b/core/src/main/scala/torch/nn/modules/activation/Softmax.scala index 986fbd2c..e1891076 100644 --- a/core/src/main/scala/torch/nn/modules/activation/Softmax.scala +++ b/core/src/main/scala/torch/nn/modules/activation/Softmax.scala @@ -14,7 +14,10 @@ * limitations under the License. */ -package torch.nn.modules.activation +package torch +package nn +package modules +package activation import org.bytedeco.pytorch import org.bytedeco.pytorch.SoftmaxImpl @@ -28,7 +31,7 @@ import torch.{DType, Tensor} * * When the input Tensor is a sparse tensor then the unspecifed values are treated as ``-inf``. 
*/ -final class Softmax(dim: Int) extends Module: +final class Softmax[D <: DType: Default](dim: Int) extends TensorModule[D]: override val nativeModule: SoftmaxImpl = SoftmaxImpl(dim) - def apply[D <: DType](t: Tensor[D]): Tensor[D] = Tensor(nativeModule.forward(t.native)) + def apply(t: Tensor[D]): Tensor[D] = Tensor(nativeModule.forward(t.native)) diff --git a/core/src/main/scala/torch/nn/modules/batchnorm/BatchNorm2d.scala b/core/src/main/scala/torch/nn/modules/batchnorm/BatchNorm2d.scala index 2a6ea11c..e106056f 100644 --- a/core/src/main/scala/torch/nn/modules/batchnorm/BatchNorm2d.scala +++ b/core/src/main/scala/torch/nn/modules/batchnorm/BatchNorm2d.scala @@ -24,7 +24,7 @@ import org.bytedeco.pytorch import sourcecode.Name import org.bytedeco.pytorch.BatchNorm2dImpl import org.bytedeco.pytorch.BatchNormOptions -import torch.nn.modules.{HasParams, HasWeight, TensorModule} +import torch.nn.modules.{HasParams, HasWeight} // format: off diff --git a/core/src/main/scala/torch/nn/modules/conv/Conv2d.scala b/core/src/main/scala/torch/nn/modules/conv/Conv2d.scala index 32c88c7a..fed72615 100644 --- a/core/src/main/scala/torch/nn/modules/conv/Conv2d.scala +++ b/core/src/main/scala/torch/nn/modules/conv/Conv2d.scala @@ -26,7 +26,7 @@ import sourcecode.Name import torch.Tensor import torch.internal.NativeConverters.toNative import torch.nn.modules.conv.Conv2d.PaddingMode -import torch.nn.modules.{HasParams, TensorModule} +import torch.nn.modules.{HasParams} /** Applies a 2D convolution over an input signal composed of several input planes. * diff --git a/core/src/main/scala/torch/nn/modules/linear/Identity.scala b/core/src/main/scala/torch/nn/modules/linear/Identity.scala index b1d256a4..88920daa 100644 --- a/core/src/main/scala/torch/nn/modules/linear/Identity.scala +++ b/core/src/main/scala/torch/nn/modules/linear/Identity.scala @@ -28,7 +28,7 @@ import torch.{DType, Tensor} * * @group nn_linear */ -final class Identity(args: Any*) extends Module: +final class Identity[D <: DType: Default](args: Any*) extends TensorModule[D]: override val nativeModule: IdentityImpl = IdentityImpl() - def apply[D <: DType](t: Tensor[D]): Tensor[D] = Tensor(nativeModule.forward(t.native)) + def apply(t: Tensor[D]): Tensor[D] = Tensor(nativeModule.forward(t.native)) diff --git a/core/src/main/scala/torch/nn/modules/linear/Linear.scala b/core/src/main/scala/torch/nn/modules/linear/Linear.scala index 65ab73f5..3d39a3c3 100644 --- a/core/src/main/scala/torch/nn/modules/linear/Linear.scala +++ b/core/src/main/scala/torch/nn/modules/linear/Linear.scala @@ -22,7 +22,7 @@ package linear import org.bytedeco.pytorch import org.bytedeco.pytorch.{LinearImpl, LinearOptions} import torch.Tensor -import torch.nn.modules.{HasParams, TensorModule} +import torch.nn.modules.{HasParams} /** Applies a linear transformation to the incoming data: $y = xA^T + b$ * @@ -53,7 +53,7 @@ final class Linear[ParamType <: FloatNN: Default]( bias: Boolean = true // dtype: ParamType = defaultDType[ParamType] ) extends HasParams[ParamType] - with (TensorModule[ParamType]): + with TensorModule[ParamType]: private val options = new LinearOptions(inFeatures, outFeatures) options.bias().put(bias) override private[torch] val nativeModule: LinearImpl = new LinearImpl(options) diff --git a/core/src/main/scala/torch/nn/modules/normalization/GroupNorm.scala b/core/src/main/scala/torch/nn/modules/normalization/GroupNorm.scala index bcd588bf..d58ad75e 100644 --- a/core/src/main/scala/torch/nn/modules/normalization/GroupNorm.scala +++ 
b/core/src/main/scala/torch/nn/modules/normalization/GroupNorm.scala @@ -21,7 +21,6 @@ package normalization import org.bytedeco.pytorch import org.bytedeco.pytorch.{GroupNormImpl, GroupNormOptions} -import torch.nn.modules.TensorModule import torch.{DType, Tensor} /** Applies Group Normalization over a mini-batch of inputs @@ -36,12 +35,13 @@ import torch.{DType, Tensor} * a boolean value that when set to `true`, this module has learnable per-channel affine * parameters initialized to ones (for weights) and zeros (for biases) */ -final class GroupNorm[ParamType <: DType]( +final class GroupNorm[ParamType <: FloatNN | ComplexNN: Default]( numGroups: Int, numChannels: Int, eps: Double = 1e-05, affine: Boolean = true -) extends TensorModule[ParamType]: +) extends HasWeight[ParamType] + with TensorModule[ParamType]: private val options: GroupNormOptions = GroupNormOptions(numGroups, numChannels) options.eps().put(eps) options.affine().put(affine) diff --git a/core/src/main/scala/torch/nn/modules/pooling/AdaptiveAvgPool2d.scala b/core/src/main/scala/torch/nn/modules/pooling/AdaptiveAvgPool2d.scala index a685cb5e..e3c6e415 100644 --- a/core/src/main/scala/torch/nn/modules/pooling/AdaptiveAvgPool2d.scala +++ b/core/src/main/scala/torch/nn/modules/pooling/AdaptiveAvgPool2d.scala @@ -32,7 +32,7 @@ import org.bytedeco.pytorch.LongOptional * The output is of size H x W, for any input size. The number of output features is equal to the * number of input planes. */ -final class AdaptiveAvgPool2d( +final class AdaptiveAvgPool2d[D <: BFloat16 | Float32 | Float64: Default]( outputSize: Int | Option[Int] | (Option[Int], Option[Int]) | (Int, Int) ) extends Module { @@ -49,7 +49,7 @@ final class AdaptiveAvgPool2d( nativeOutputSize.get(0) ) - def apply[D <: BFloat16 | Float32 | Float64](t: Tensor[D]): Tensor[D] = Tensor( + def apply(t: Tensor[D]): Tensor[D] = Tensor( nativeModule.forward(t.native) ) } diff --git a/core/src/main/scala/torch/nn/modules/pooling/MaxPool2d.scala b/core/src/main/scala/torch/nn/modules/pooling/MaxPool2d.scala index 9f21cd70..10785797 100644 --- a/core/src/main/scala/torch/nn/modules/pooling/MaxPool2d.scala +++ b/core/src/main/scala/torch/nn/modules/pooling/MaxPool2d.scala @@ -23,19 +23,18 @@ import org.bytedeco.javacpp.LongPointer import org.bytedeco.pytorch import org.bytedeco.pytorch.{MaxPool2dImpl, MaxPool2dOptions} import torch.internal.NativeConverters.toNative -import torch.nn.modules.{HasParams, TensorModule} +import torch.nn.modules.{HasParams} import torch.{BFloat16, Float32, Float64, Tensor} /** Applies a 2D max pooling over an input signal composed of several input planes. 
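The GroupNorm and AdaptiveAvgPool2d hunks above give both modules a type parameter with a `Default` bound, so constructing them without an annotation yields the Float32 variants. A usage sketch under that assumption follows; the import paths match the files touched here, while the tensor shapes and variable names are illustrative.

```scala
import torch.*
import torch.nn.modules.normalization.GroupNorm
import torch.nn.modules.pooling.AdaptiveAvgPool2d

val x = torch.rand(Seq(8, 6, 32, 32)) // Tensor[Float32], N x C x H x W

val gn = GroupNorm(numGroups = 3, numChannels = 6) // GroupNorm[Float32] via Default.float32
val normed = gn(x)                                 // weight and bias are Tensor[Float32]

val pool = AdaptiveAvgPool2d((1, 1))               // AdaptiveAvgPool2d[Float32]
val pooled = pool(normed)                          // shape 8 x 6 x 1 x 1

// Fix the element type explicitly instead of relying on the default:
val pool64 = AdaptiveAvgPool2d[Float64]((1, 1))    // accepts Tensor[Float64] inputs
```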
*/ -final class MaxPool2d[ParamType <: BFloat16 | Float32 | Float64: Default]( +final class MaxPool2d[D <: BFloat16 | Float32 | Float64: Default]( kernelSize: Int | (Int, Int), stride: Option[Int | (Int, Int)] = None, padding: Int | (Int, Int) = 0, dilation: Int | (Int, Int) = 1, // returnIndices: Boolean = false, ceilMode: Boolean = false -) extends HasParams[ParamType] - with TensorModule[ParamType]: +) extends TensorModule[D]: private val options: MaxPool2dOptions = MaxPool2dOptions(toNative(kernelSize)) stride.foreach(s => options.stride().put(toNative(s))) @@ -44,10 +43,9 @@ final class MaxPool2d[ParamType <: BFloat16 | Float32 | Float64: Default]( options.ceil_mode().put(ceilMode) override private[torch] val nativeModule: MaxPool2dImpl = MaxPool2dImpl(options) - nativeModule.to(paramType.toScalarType, false) override def toString(): String = s"MaxPool2d(kernelSize=$kernelSize, stride=$stride, padding=$padding, dilation=$dilation, ceilMode=$ceilMode)" - def apply(t: Tensor[ParamType]): Tensor[ParamType] = Tensor(nativeModule.forward(t.native)) + def apply(t: Tensor[D]): Tensor[D] = Tensor(nativeModule.forward(t.native)) // TODO forward_with_indices diff --git a/core/src/main/scala/torch/nn/package.scala b/core/src/main/scala/torch/nn/package.scala index 2d899379..ded65ebb 100644 --- a/core/src/main/scala/torch/nn/package.scala +++ b/core/src/main/scala/torch/nn/package.scala @@ -25,7 +25,6 @@ package torch package object nn { export modules.Module - export modules.Default export modules.activation.Softmax export modules.activation.ReLU diff --git a/docs/modules.md b/docs/modules.md index e16f8603..ca2e2574 100644 --- a/docs/modules.md +++ b/docs/modules.md @@ -13,7 +13,7 @@ import torch.* import torch.nn import torch.nn.functional as F -class LeNet[D <: BFloat16 | Float32: nn.Default] extends nn.Module: +class LeNet[D <: BFloat16 | Float32: Default] extends nn.Module: val conv1 = register(nn.Conv2d(1, 6, 5)) val pool = register(nn.MaxPool2d((2, 2))) val conv2 = register(nn.Conv2d(6, 16, 5)) diff --git a/docs/tutorial/buildmodel.md b/docs/tutorial/buildmodel.md index 0d0eb6fb..5a01188d 100644 --- a/docs/tutorial/buildmodel.md +++ b/docs/tutorial/buildmodel.md @@ -78,7 +78,7 @@ We get the prediction probabilities by passing it through an instance of the ``n ```scala mdoc val X = torch.rand(Seq(1, 28, 28), device=device) val logits = model(X) -val predProbab = nn.Softmax(dim=1)(logits) +val predProbab = nn.Softmax(dim=1).apply(logits) val yPred = predProbab.argmax(1) println(s"Predicted class: $yPred") ``` diff --git a/examples/src/main/scala/ImageClassifier.scala b/examples/src/main/scala/ImageClassifier.scala index f9d2a8b3..a363db1c 100644 --- a/examples/src/main/scala/ImageClassifier.scala +++ b/examples/src/main/scala/ImageClassifier.scala @@ -16,14 +16,13 @@ //> using scala "3.3" //> using repository "sonatype-s01:snapshots" -//> using lib "dev.storch::vision:0.0-ab8d84c-SNAPSHOT" +//> using lib "dev.storch::vision:0.0-795485b-SNAPSHOT" //> using lib "me.tongfei:progressbar:0.9.5" //> using lib "com.github.alexarchambault::case-app:2.1.0-M24" //> using lib "org.scala-lang.modules::scala-parallel-collections:1.0.4" // replace with pytorch-platform-gpu if you have a CUDA capable GPU //> using lib "org.bytedeco:pytorch-platform:2.0.1-1.5.9" // enable for CUDA support -////> using lib "org.bytedeco:cuda-platform:12.1-8.9-1.5.9" ////> using lib "org.bytedeco:cuda-platform-redist:12.1-8.9-1.5.9" import Commands.* diff --git a/examples/src/main/scala/LeNet.scala 
b/examples/src/main/scala/LeNet.scala index 8204fdaf..bcc7aa7f 100644 --- a/examples/src/main/scala/LeNet.scala +++ b/examples/src/main/scala/LeNet.scala @@ -16,18 +16,16 @@ //> using scala "3.3" //> using repository "sonatype-s01:snapshots" -//> using lib "dev.storch::vision:0.0-ab8d84c-SNAPSHOT" +//> using lib "dev.storch::vision:0.0-795485b-SNAPSHOT" // replace with pytorch-platform-gpu if you have a CUDA capable GPU //> using lib "org.bytedeco:pytorch-platform:2.0.1-1.5.9" // enable for CUDA support -////> using lib "org.bytedeco:cuda-platform:12.1-8.9-1.5.9" ////> using lib "org.bytedeco:cuda-platform-redist:12.1-8.9-1.5.9" import torch.* import torch.nn.functional as F import torch.optim.Adam import org.bytedeco.pytorch.OutputArchive -import torch.nn.modules.Default import torchvision.datasets.MNIST import scala.util.Random import java.nio.file.Paths @@ -35,9 +33,11 @@ import torch.Device.CUDA import scala.util.Using import org.bytedeco.javacpp.PointerScope import torch.Device.CPU +import torch.nn.modules.HasParams + +// Define the model architecture +class LeNet[D <: BFloat16 | Float32: Default] extends HasParams[D] { -// define model architecture -class LeNet[D <: BFloat16 | Float32: Default] extends nn.Module { val conv1 = register(nn.Conv2d(1, 6, 5)) val pool = register(nn.MaxPool2d((2, 2))) val conv2 = register(nn.Conv2d(6, 16, 5)) diff --git a/vision/src/main/scala/torchvision/models/resnet.scala b/vision/src/main/scala/torchvision/models/resnet.scala index a5cd7115..3f4ca09f 100644 --- a/vision/src/main/scala/torchvision/models/resnet.scala +++ b/vision/src/main/scala/torchvision/models/resnet.scala @@ -17,7 +17,18 @@ package torchvision package models -import torch.{BFloat16, ComplexNN, DType, Float32, Float32Tensor, Float64, FloatNN, Tensor, nn} +import torch.{ + BFloat16, + ComplexNN, + DType, + Default, + Float32, + Float32Tensor, + Float64, + FloatNN, + Tensor, + nn +} import torch.nn.init.{Mode, NonLinearity, constant_, kaimingNormal_} import scala.collection.mutable @@ -28,12 +39,13 @@ import sourcecode.Name import torch.nn.modules.activation.ReLU import torch.nn.modules.conv.Conv2d import torch.nn.modules.pooling.{AdaptiveAvgPool2d, MaxPool2d} -import torch.nn.modules.{Default, HasWeight, Module, TensorModule} +import torch.nn.modules.{HasWeight, Module} import torchvision.transforms.* import scala.util.Using import com.sksamuel.scrimage.ImmutableImage import torch.Int32 +import torch.nn.modules.TensorModule /** ResNet architecture implementations *
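Since the LeNet example now extends `HasParams[D]`, its parameter dtype is also fixed when the model is instantiated; the removed to(dtype) conversion is no longer available as an after-the-fact escape hatch. A short usage sketch consistent with the example and docs touched above (construction, device placement, a forward pass); the batch shape follows the MNIST example, and the device choice is illustrative.

```scala
import torch.*
import torch.Device.{CPU, CUDA}

val device = CPU // switch to CUDA when a GPU build of the native libraries is available

val model = LeNet() // LeNet[Float32] via Default.float32
model.to(device)

// Parameters fixed to bfloat16 at creation time instead of converting later:
val model16 = LeNet[BFloat16]()

val batch  = torch.rand(Seq(32, 1, 28, 28), device = device)
val logits = model(batch) // Tensor[Float32]
```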
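One consequence of the new `using Default[D]` clause shows up in the buildmodel.md hunk above: chaining a second argument list directly after the constructor, as in `nn.Softmax(dim=1)(logits)`, would now be read as an attempt to pass the implicit argument explicitly, which is presumably why the tutorial switches to an explicit `.apply` call. A minimal sketch of the working form, assuming the API as defined in this patch:

```scala
import torch.*

val logits = torch.rand(Seq(1, 10))

// nn.Softmax(dim = 1)(logits) would hand `logits` to the `using Default[D]` clause,
// so the tutorial now calls apply explicitly:
val probs = nn.Softmax(dim = 1).apply(logits) // Tensor[Float32], rows summing to 1
val yPred = probs.argmax(1)
```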