From 5f18d5a471886d6684142dba7874d7914ab880eb Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?S=C3=B6ren=20Brunk?=
Date: Fri, 28 Jul 2023 21:28:10 +0200
Subject: [PATCH] Simplify module registration after improvements in new
 presets

---
 core/src/main/scala/torch/nn/modules/Module.scala          | 9 ++-------
 .../main/scala/torch/nn/modules/activation/ReLU.scala      | 5 -----
 .../main/scala/torch/nn/modules/activation/Softmax.scala   | 5 -----
 .../scala/torch/nn/modules/batchnorm/BatchNorm2d.scala     | 7 +------
 core/src/main/scala/torch/nn/modules/conv/Conv2d.scala     | 7 +------
 .../main/scala/torch/nn/modules/flatten/Flatten.scala      | 5 -----
 .../main/scala/torch/nn/modules/linear/Identity.scala      | 5 -----
 core/src/main/scala/torch/nn/modules/linear/Linear.scala   | 7 +------
 .../scala/torch/nn/modules/normalization/GroupNorm.scala   | 5 -----
 .../torch/nn/modules/pooling/AdaptiveAvgPool2d.scala       | 7 +------
 .../main/scala/torch/nn/modules/pooling/MaxPool2d.scala    | 7 +------
 11 files changed, 7 insertions(+), 62 deletions(-)

diff --git a/core/src/main/scala/torch/nn/modules/Module.scala b/core/src/main/scala/torch/nn/modules/Module.scala
index 4db7b1e3..ab0e3482 100644
--- a/core/src/main/scala/torch/nn/modules/Module.scala
+++ b/core/src/main/scala/torch/nn/modules/Module.scala
@@ -75,15 +75,10 @@ abstract class Module {
     clone._nativeModule = _nativeModule.clone(null)
     clone.asInstanceOf[this.type]
 
-  protected[torch] def registerWithParent[T <: pytorch.Module](parent: T)(using
-      name: sourcecode.Name
-  ): Unit =
-    parent.register_module(name.value, nativeModule)
-
   def register[M <: Module](child: M)(using name: sourcecode.Name) =
     // println(s"registering ${name.value}:$child")
     childModules = childModules.updated(name.value, child)
-    child.registerWithParent(this.nativeModule)
+    nativeModule.register_module(name.value, child.nativeModule)
     child
 
   def register[D <: DType](t: Tensor[D], requiresGrad: Boolean = true)(using
@@ -100,7 +95,7 @@ abstract class Module {
 
   def to(device: Device): this.type =
     // val nativeCopy = nativeModule.clone()
-    nativeModule.asModule.to(device.toNative, false)
+    nativeModule.to(device.toNative, false)
     // copy
     // val clone: this.type = copy()
     // clone.nativeModule = nativeCopy
diff --git a/core/src/main/scala/torch/nn/modules/activation/ReLU.scala b/core/src/main/scala/torch/nn/modules/activation/ReLU.scala
index 1e299606..690d633e 100644
--- a/core/src/main/scala/torch/nn/modules/activation/ReLU.scala
+++ b/core/src/main/scala/torch/nn/modules/activation/ReLU.scala
@@ -34,11 +34,6 @@ final class ReLU[D <: DType: Default](inplace: Boolean = false) extends TensorMo
 
   override protected[torch] val nativeModule: ReLUImpl = ReLUImpl()
 
-  override def registerWithParent[M <: pytorch.Module](parent: M)(using
-      name: sourcecode.Name
-  ): Unit =
-    parent.register_module(name.value, nativeModule)
-
   def apply(t: Tensor[D]): Tensor[D] = Tensor(nativeModule.forward(t.native))
 
   override def toString = getClass().getSimpleName()
diff --git a/core/src/main/scala/torch/nn/modules/activation/Softmax.scala b/core/src/main/scala/torch/nn/modules/activation/Softmax.scala
index 1da2211b..986fbd2c 100644
--- a/core/src/main/scala/torch/nn/modules/activation/Softmax.scala
+++ b/core/src/main/scala/torch/nn/modules/activation/Softmax.scala
@@ -31,9 +31,4 @@ import torch.{DType, Tensor}
 final class Softmax(dim: Int) extends Module:
   override val nativeModule: SoftmaxImpl = SoftmaxImpl(dim)
 
-  override def registerWithParent[M <: pytorch.Module](parent: M)(using
-      name: sourcecode.Name
-  ): Unit =
-    parent.register_module(name.value, nativeModule)
-
   def apply[D <: DType](t: Tensor[D]): Tensor[D] = Tensor(nativeModule.forward(t.native))
diff --git a/core/src/main/scala/torch/nn/modules/batchnorm/BatchNorm2d.scala b/core/src/main/scala/torch/nn/modules/batchnorm/BatchNorm2d.scala
index 9c25c128..2a6ea11c 100644
--- a/core/src/main/scala/torch/nn/modules/batchnorm/BatchNorm2d.scala
+++ b/core/src/main/scala/torch/nn/modules/batchnorm/BatchNorm2d.scala
@@ -119,12 +119,7 @@ final class BatchNorm2d[ParamType <: FloatNN | ComplexNN: Default](
   options.track_running_stats().put(trackRunningStats)
 
   override private[torch] val nativeModule: BatchNorm2dImpl = BatchNorm2dImpl(options)
-  nativeModule.asModule.to(paramType.toScalarType, false)
-
-  override def registerWithParent[M <: pytorch.Module](parent: M)(using
-      name: sourcecode.Name
-  ): Unit =
-    parent.register_module(name.value, nativeModule)
+  nativeModule.to(paramType.toScalarType, false)
 
   // TODO weight, bias etc. are undefined if affine = false. We need to take that into account
   val weight: Tensor[ParamType] = Tensor[ParamType](nativeModule.weight)
diff --git a/core/src/main/scala/torch/nn/modules/conv/Conv2d.scala b/core/src/main/scala/torch/nn/modules/conv/Conv2d.scala
index f419c274..32c88c7a 100644
--- a/core/src/main/scala/torch/nn/modules/conv/Conv2d.scala
+++ b/core/src/main/scala/torch/nn/modules/conv/Conv2d.scala
@@ -59,12 +59,7 @@ final class Conv2d[ParamType <: FloatNN | ComplexNN: Default](
   options.padding_mode().put(paddingModeNative)
 
   override private[torch] val nativeModule: Conv2dImpl = Conv2dImpl(options)
-  nativeModule.asModule.to(paramType.toScalarType, false)
-
-  override def registerWithParent[M <: pytorch.Module](parent: M)(using
-      name: sourcecode.Name
-  ): Unit =
-    parent.register_module(name.value, nativeModule)
+  nativeModule.to(paramType.toScalarType, false)
 
   def apply(t: Tensor[ParamType]): Tensor[ParamType] = Tensor(nativeModule.forward(t.native))
diff --git a/core/src/main/scala/torch/nn/modules/flatten/Flatten.scala b/core/src/main/scala/torch/nn/modules/flatten/Flatten.scala
index 6c8b5844..34fc7b9c 100644
--- a/core/src/main/scala/torch/nn/modules/flatten/Flatten.scala
+++ b/core/src/main/scala/torch/nn/modules/flatten/Flatten.scala
@@ -59,11 +59,6 @@ final class Flatten[D <: DType: Default](startDim: Int = 1, endDim: Int = -1)
 
   override val nativeModule: FlattenImpl = FlattenImpl(options)
 
-  override def registerWithParent[T <: pytorch.Module](parent: T)(using
-      name: sourcecode.Name
-  ): Unit =
-    parent.register_module(name.value, nativeModule)
-
   def apply(t: Tensor[D]): Tensor[D] = Tensor(nativeModule.forward(t.native))
 
   override def toString = getClass().getSimpleName()
diff --git a/core/src/main/scala/torch/nn/modules/linear/Identity.scala b/core/src/main/scala/torch/nn/modules/linear/Identity.scala
index 2beef814..b1d256a4 100644
--- a/core/src/main/scala/torch/nn/modules/linear/Identity.scala
+++ b/core/src/main/scala/torch/nn/modules/linear/Identity.scala
@@ -31,9 +31,4 @@ import torch.{DType, Tensor}
 final class Identity(args: Any*) extends Module:
   override val nativeModule: IdentityImpl = IdentityImpl()
 
-  override def registerWithParent[M <: pytorch.Module](parent: M)(using
-      name: sourcecode.Name
-  ): Unit =
-    parent.register_module(name.value, nativeModule)
-
   def apply[D <: DType](t: Tensor[D]): Tensor[D] = Tensor(nativeModule.forward(t.native))
diff --git a/core/src/main/scala/torch/nn/modules/linear/Linear.scala b/core/src/main/scala/torch/nn/modules/linear/Linear.scala
index 26ff8cc5..65ab73f5 100644
--- a/core/src/main/scala/torch/nn/modules/linear/Linear.scala
+++ b/core/src/main/scala/torch/nn/modules/linear/Linear.scala
@@ -57,12 +57,7 @@ final class Linear[ParamType <: FloatNN: Default](
   private val options = new LinearOptions(inFeatures, outFeatures)
   options.bias().put(bias)
   override private[torch] val nativeModule: LinearImpl = new LinearImpl(options)
-  nativeModule.asModule.to(paramType.toScalarType, false)
-
-  override def registerWithParent[T <: pytorch.Module](parent: T)(using
-      name: sourcecode.Name
-  ): Unit =
-    parent.register_module(name.value, nativeModule)
+  nativeModule.to(paramType.toScalarType, false)
 
   def apply(input: Tensor[ParamType]): Tensor[ParamType] = Tensor(
     nativeModule.forward(input.native)
diff --git a/core/src/main/scala/torch/nn/modules/normalization/GroupNorm.scala b/core/src/main/scala/torch/nn/modules/normalization/GroupNorm.scala
index fd1bc268..bcd588bf 100644
--- a/core/src/main/scala/torch/nn/modules/normalization/GroupNorm.scala
+++ b/core/src/main/scala/torch/nn/modules/normalization/GroupNorm.scala
@@ -48,11 +48,6 @@ final class GroupNorm[ParamType <: DType](
 
   override private[torch] val nativeModule: GroupNormImpl = GroupNormImpl(options)
 
-  override def registerWithParent[M <: pytorch.Module](parent: M)(using
-      name: sourcecode.Name
-  ): Unit =
-    parent.register_module(name.value, nativeModule)
-
   val weight: Tensor[ParamType] = Tensor[ParamType](nativeModule.weight)
   val bias: Tensor[ParamType] = Tensor[ParamType](nativeModule.bias)
diff --git a/core/src/main/scala/torch/nn/modules/pooling/AdaptiveAvgPool2d.scala b/core/src/main/scala/torch/nn/modules/pooling/AdaptiveAvgPool2d.scala
index 04bf4ca8..a685cb5e 100644
--- a/core/src/main/scala/torch/nn/modules/pooling/AdaptiveAvgPool2d.scala
+++ b/core/src/main/scala/torch/nn/modules/pooling/AdaptiveAvgPool2d.scala
@@ -45,15 +45,10 @@ final class AdaptiveAvgPool2d(
     case x: Option[Int] => new LongOptionalVector(x.toOptional, x.toOptional)
 
-  override private[torch] val nativeModule: AdaptiveAvgPool2dImpl = AdaptiveAvgPool2dImpl(
+  override protected[torch] val nativeModule: AdaptiveAvgPool2dImpl = AdaptiveAvgPool2dImpl(
     nativeOutputSize.get(0)
   )
 
-  override def registerWithParent[T <: pytorch.Module](parent: T)(using
-      name: sourcecode.Name
-  ): Unit =
-    parent.register_module(name.value, nativeModule)
-
   def apply[D <: BFloat16 | Float32 | Float64](t: Tensor[D]): Tensor[D] = Tensor(
     nativeModule.forward(t.native)
   )
diff --git a/core/src/main/scala/torch/nn/modules/pooling/MaxPool2d.scala b/core/src/main/scala/torch/nn/modules/pooling/MaxPool2d.scala
index abed0ad7..9f21cd70 100644
--- a/core/src/main/scala/torch/nn/modules/pooling/MaxPool2d.scala
+++ b/core/src/main/scala/torch/nn/modules/pooling/MaxPool2d.scala
@@ -44,12 +44,7 @@ final class MaxPool2d[ParamType <: BFloat16 | Float32 | Float64: Default](
   options.ceil_mode().put(ceilMode)
 
   override private[torch] val nativeModule: MaxPool2dImpl = MaxPool2dImpl(options)
-  nativeModule.asModule.to(paramType.toScalarType, false)
-
-  override def registerWithParent[M <: pytorch.Module](parent: M)(using
-      name: sourcecode.Name
-  ): Unit =
-    parent.register_module(name.value, nativeModule)
+  nativeModule.to(paramType.toScalarType, false)
 
   override def toString(): String =
     s"MaxPool2d(kernelSize=$kernelSize, stride=$stride, padding=$padding, dilation=$dilation, ceilMode=$ceilMode)"
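
Note (not part of the patch): as a hedged sketch of what the simplified registration path looks like from user code, the example below defines a small composite module. The class name, field names, and imports are illustrative assumptions rather than code from this diff; only `register`, `Linear`, and `ReLU` and their packages are taken from the files touched above. The point is that `register` now records the child and calls `register_module` on the parent's `nativeModule` directly, so concrete modules no longer need a per-class `registerWithParent` override.

```scala
// Illustrative sketch only: TinyNet and its fields are hypothetical names.
import torch.{Float32, Tensor}
import torch.nn.modules.Module
import torch.nn.modules.activation.ReLU
import torch.nn.modules.linear.Linear

class TinyNet extends Module:
  // `register` captures the field name via sourcecode.Name, records the child
  // in childModules, and registers the child's native module with this
  // module's nativeModule (see the updated Module.register in the diff).
  val fc1  = register(Linear[Float32](784, 64))
  val relu = register(ReLU[Float32]())
  val fc2  = register(Linear[Float32](64, 10))

  def apply(x: Tensor[Float32]): Tensor[Float32] = fc2(relu(fc1(x)))
```

Because registration is now handled entirely in `Module.register`, the per-module `registerWithParent` overrides removed in this patch had become dead weight; deleting them changes nothing for code written in the style sketched above.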