Simplify module registration after improvements in new presets
sbrunk committed Jul 28, 2023
1 parent 795485b commit 5f18d5a
Showing 11 changed files with 7 additions and 62 deletions.
9 changes: 2 additions & 7 deletions core/src/main/scala/torch/nn/modules/Module.scala
@@ -75,15 +75,10 @@ abstract class Module {
     clone._nativeModule = _nativeModule.clone(null)
     clone.asInstanceOf[this.type]
 
-  protected[torch] def registerWithParent[T <: pytorch.Module](parent: T)(using
-      name: sourcecode.Name
-  ): Unit =
-    parent.register_module(name.value, nativeModule)
-
   def register[M <: Module](child: M)(using name: sourcecode.Name) =
     // println(s"registering ${name.value}:$child")
     childModules = childModules.updated(name.value, child)
-    child.registerWithParent(this.nativeModule)
+    nativeModule.register_module(name.value, child.nativeModule)
     child
 
   def register[D <: DType](t: Tensor[D], requiresGrad: Boolean = true)(using
@@ -100,7 +95,7 @@ abstract class Module {
 
   def to(device: Device): this.type =
     // val nativeCopy = nativeModule.clone()
-    nativeModule.asModule.to(device.toNative, false)
+    nativeModule.to(device.toNative, false)
     // copy
     // val clone: this.type = copy()
     // clone.nativeModule = nativeCopy
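
With registerWithParent gone, wiring up a model needs only register: it derives the field name via sourcecode.Name and forwards straight to the native register_module. A minimal sketch of a user-defined module under the new scheme (a hypothetical model; the torch.nn constructor shapes and package layout are assumptions, not shown in this commit):

import torch.*
import torch.nn

class Net extends nn.Module:
  // register returns the child, so each val keeps its typed handle;
  // under the hood this calls nativeModule.register_module("fc1", ...)
  val fc1 = register(nn.Linear(784, 64))
  val act = register(nn.ReLU[Float32]())
  val fc2 = register(nn.Linear(64, 10))

  def apply(x: Tensor[Float32]): Tensor[Float32] = fc2(act(fc1(x)))

Since children are now registered on the native side at register time, the simplified to(device) in the second hunk moves them all along with the parent.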
5 changes: 0 additions & 5 deletions core/src/main/scala/torch/nn/modules/activation/ReLU.scala
@@ -34,11 +34,6 @@ final class ReLU[D <: DType: Default](inplace: Boolean = false) extends TensorMo
 
   override protected[torch] val nativeModule: ReLUImpl = ReLUImpl()
 
-  override def registerWithParent[M <: pytorch.Module](parent: M)(using
-      name: sourcecode.Name
-  ): Unit =
-    parent.register_module(name.value, nativeModule)
-
   def apply(t: Tensor[D]): Tensor[D] = Tensor(nativeModule.forward(t.native))
 
   override def toString = getClass().getSimpleName()
5 changes: 0 additions & 5 deletions core/src/main/scala/torch/nn/modules/activation/Softmax.scala
@@ -31,9 +31,4 @@ import torch.{DType, Tensor}
 final class Softmax(dim: Int) extends Module:
   override val nativeModule: SoftmaxImpl = SoftmaxImpl(dim)
 
-  override def registerWithParent[M <: pytorch.Module](parent: M)(using
-      name: sourcecode.Name
-  ): Unit =
-    parent.register_module(name.value, nativeModule)
-
   def apply[D <: DType](t: Tensor[D]): Tensor[D] = Tensor(nativeModule.forward(t.native))
@@ -119,12 +119,7 @@ final class BatchNorm2d[ParamType <: FloatNN | ComplexNN: Default](
   options.track_running_stats().put(trackRunningStats)
 
   override private[torch] val nativeModule: BatchNorm2dImpl = BatchNorm2dImpl(options)
-  nativeModule.asModule.to(paramType.toScalarType, false)
-
-  override def registerWithParent[M <: pytorch.Module](parent: M)(using
-      name: sourcecode.Name
-  ): Unit =
-    parent.register_module(name.value, nativeModule)
+  nativeModule.to(paramType.toScalarType, false)
 
   // TODO weight, bias etc. are undefined if affine = false. We need to take that into account
   val weight: Tensor[ParamType] = Tensor[ParamType](nativeModule.weight)
7 changes: 1 addition & 6 deletions core/src/main/scala/torch/nn/modules/conv/Conv2d.scala
@@ -59,12 +59,7 @@ final class Conv2d[ParamType <: FloatNN | ComplexNN: Default](
   options.padding_mode().put(paddingModeNative)
 
   override private[torch] val nativeModule: Conv2dImpl = Conv2dImpl(options)
-  nativeModule.asModule.to(paramType.toScalarType, false)
-
-  override def registerWithParent[M <: pytorch.Module](parent: M)(using
-      name: sourcecode.Name
-  ): Unit =
-    parent.register_module(name.value, nativeModule)
+  nativeModule.to(paramType.toScalarType, false)
 
   def apply(t: Tensor[ParamType]): Tensor[ParamType] = Tensor(nativeModule.forward(t.native))
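
The replacement line keeps the construction-time dtype cast, just without the asModule detour: native parameters are cast to the Scala-level ParamType as soon as the module is built. A sketch of the observable effect (the positional Conv2d arguments, the weight accessor, and the float64 comparison are assumptions, not shown in this diff):

// inChannels = 3, outChannels = 16, kernelSize = 3
val conv = nn.Conv2d[Float64](3, 16, 3)
// libtorch creates parameters as float32 by default; the cast in the
// constructor makes them match the declared ParamType
assert(conv.weight.dtype == float64)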
5 changes: 0 additions & 5 deletions core/src/main/scala/torch/nn/modules/flatten/Flatten.scala
@@ -59,11 +59,6 @@ final class Flatten[D <: DType: Default](startDim: Int = 1, endDim: Int = -1)
 
   override val nativeModule: FlattenImpl = FlattenImpl(options)
 
-  override def registerWithParent[T <: pytorch.Module](parent: T)(using
-      name: sourcecode.Name
-  ): Unit =
-    parent.register_module(name.value, nativeModule)
-
   def apply(t: Tensor[D]): Tensor[D] = Tensor(nativeModule.forward(t.native))
 
   override def toString = getClass().getSimpleName()
5 changes: 0 additions & 5 deletions core/src/main/scala/torch/nn/modules/linear/Identity.scala
@@ -31,9 +31,4 @@ import torch.{DType, Tensor}
 final class Identity(args: Any*) extends Module:
   override val nativeModule: IdentityImpl = IdentityImpl()
 
-  override def registerWithParent[M <: pytorch.Module](parent: M)(using
-      name: sourcecode.Name
-  ): Unit =
-    parent.register_module(name.value, nativeModule)
-
   def apply[D <: DType](t: Tensor[D]): Tensor[D] = Tensor(nativeModule.forward(t.native))
7 changes: 1 addition & 6 deletions core/src/main/scala/torch/nn/modules/linear/Linear.scala
@@ -57,12 +57,7 @@ final class Linear[ParamType <: FloatNN: Default](
   private val options = new LinearOptions(inFeatures, outFeatures)
   options.bias().put(bias)
   override private[torch] val nativeModule: LinearImpl = new LinearImpl(options)
-  nativeModule.asModule.to(paramType.toScalarType, false)
-
-  override def registerWithParent[T <: pytorch.Module](parent: T)(using
-      name: sourcecode.Name
-  ): Unit =
-    parent.register_module(name.value, nativeModule)
+  nativeModule.to(paramType.toScalarType, false)
 
   def apply(input: Tensor[ParamType]): Tensor[ParamType] = Tensor(
     nativeModule.forward(input.native)
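
Module's other register overload (for bare tensors, visible in the trailing context of the first Module.scala hunk) is untouched by this commit, so custom parameters register the same way. A sketch, assuming the overload returns the registered tensor and that a ones constructor of this shape exists:

import torch.*

class Scale extends nn.Module:
  // registered under the parameter name "weight"; requiresGrad defaults to true
  val weight = register(torch.ones[Float32](Seq(10)))
  def apply(x: Tensor[Float32]): Tensor[Float32] = x * weight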
@@ -48,11 +48,6 @@ final class GroupNorm[ParamType <: DType](
 
   override private[torch] val nativeModule: GroupNormImpl = GroupNormImpl(options)
 
-  override def registerWithParent[M <: pytorch.Module](parent: M)(using
-      name: sourcecode.Name
-  ): Unit =
-    parent.register_module(name.value, nativeModule)
-
   val weight: Tensor[ParamType] = Tensor[ParamType](nativeModule.weight)
   val bias: Tensor[ParamType] = Tensor[ParamType](nativeModule.bias)
@@ -45,15 +45,10 @@ final class AdaptiveAvgPool2d(
     case x: Option[Int] =>
       new LongOptionalVector(x.toOptional, x.toOptional)
 
-  override private[torch] val nativeModule: AdaptiveAvgPool2dImpl = AdaptiveAvgPool2dImpl(
+  override protected[torch] val nativeModule: AdaptiveAvgPool2dImpl = AdaptiveAvgPool2dImpl(
     nativeOutputSize.get(0)
   )
 
-  override def registerWithParent[T <: pytorch.Module](parent: T)(using
-      name: sourcecode.Name
-  ): Unit =
-    parent.register_module(name.value, nativeModule)
-
   def apply[D <: BFloat16 | Float32 | Float64](t: Tensor[D]): Tensor[D] = Tensor(
     nativeModule.forward(t.native)
   )
7 changes: 1 addition & 6 deletions core/src/main/scala/torch/nn/modules/pooling/MaxPool2d.scala
@@ -44,12 +44,7 @@ final class MaxPool2d[ParamType <: BFloat16 | Float32 | Float64: Default](
   options.ceil_mode().put(ceilMode)
 
   override private[torch] val nativeModule: MaxPool2dImpl = MaxPool2dImpl(options)
-  nativeModule.asModule.to(paramType.toScalarType, false)
-
-  override def registerWithParent[M <: pytorch.Module](parent: M)(using
-      name: sourcecode.Name
-  ): Unit =
-    parent.register_module(name.value, nativeModule)
+  nativeModule.to(paramType.toScalarType, false)
 
   override def toString(): String =
     s"MaxPool2d(kernelSize=$kernelSize, stride=$stride, padding=$padding, dilation=$dilation, ceilMode=$ceilMode)"
