Add more modules: LogSoftmax, Tanh, BatchNorm1d, Embedding
davoclavo committed Jul 6, 2023
1 parent 19abf3e commit 9a6409b
Showing 8 changed files with 293 additions and 2 deletions.
5 changes: 4 additions & 1 deletion core/src/main/scala/torch/internal/NativeConverters.scala
@@ -29,7 +29,7 @@ import org.bytedeco.pytorch.{
}

import scala.reflect.Typeable
import org.bytedeco.javacpp.LongPointer
import org.bytedeco.javacpp.{LongPointer, DoublePointer}
import org.bytedeco.pytorch.GenericDict
import org.bytedeco.pytorch.GenericDictIterator
import spire.math.Complex
@@ -78,6 +78,9 @@ private[torch] object NativeConverters:
case (h, w) => LongPointer(Array(h.toLong, w.toLong)*)
case (t, h, w) => LongPointer(Array(t.toLong, h.toLong, w.toLong)*)

given doubleToDoublePointer: Conversion[Double, DoublePointer] = (input: Double) =>
DoublePointer(Array(input)*)

extension (x: ScalaType)
def toScalar: pytorch.Scalar = x match
case x: Boolean => pytorch.Scalar(if x then 1: Byte else 0: Byte)
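
The new `doubleToDoublePointer` given lets a plain `Double` be used where JavaCPP expects a native `DoublePointer`. A minimal sketch of how the conversion behaves, assuming it is called from code inside the `torch` package (`NativeConverters` is `private[torch]`):

import scala.language.implicitConversions
import org.bytedeco.javacpp.DoublePointer
import torch.internal.NativeConverters.doubleToDoublePointer

val eps: DoublePointer = 1e-5 // Double is wrapped into a one-element native DoublePointer
println(eps.get(0))           // prints 1.0E-5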
4 changes: 4 additions & 0 deletions core/src/main/scala/torch/nn/modules/Module.scala
@@ -149,3 +149,7 @@ trait HasWeight[ParamType <: FloatNN | ComplexNN]:
/** Transforms a single tensor into another one of the same type. */
trait TensorModule[D <: DType] extends Module with (Tensor[D] => Tensor[D]):
override def toString(): String = "TensorModule"

trait TensorModuleBase[D <: DType, D2 <: DType] extends Module with (Tensor[D] => Tensor[D2]) {
override def toString() = "TensorModuleBase"
}
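
`TensorModuleBase` complements the existing `TensorModule`: it lets the output dtype differ from the input dtype. A hypothetical typing sketch (the modules referenced are added later in this commit):

import torch.*
import torch.nn.modules.{TensorModule, TensorModuleBase}

// Embedding maps Int64 indices to floating-point vectors, so it cannot be a TensorModule[Int64].
val indexToVector: TensorModuleBase[Int64, Float32] = nn.Embedding[Float32](10, 4)
// Tanh keeps the dtype, so the single-dtype TensorModule[Float32] suffices.
val sameDtype: TensorModule[Float32] = nn.Tanh[Float32]()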
40 changes: 40 additions & 0 deletions core/src/main/scala/torch/nn/modules/activation/LogSoftmax.scala
@@ -0,0 +1,40 @@
/*
* Copyright 2022 storch.dev
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package torch
package nn
package modules
package activation

import org.bytedeco.pytorch
import org.bytedeco.pytorch.LogSoftmaxImpl
import torch.nn.modules.Module
import torch.{DType, Tensor}

/** Applies the log(Softmax(x)) function to an n-dimensional input Tensor. The LogSoftmax
* formulation can be simplified as:
*
* $\text{LogSoftmax}(x_{i}) = \log\left(\frac{\exp(x_{i})}{\sum_{j} \exp(x_{j})}\right)$
*/
final class LogSoftmax(dim: Int) extends Module:
override val nativeModule: LogSoftmaxImpl = LogSoftmaxImpl(dim)

override def registerWithParent[M <: pytorch.Module](parent: M)(using
name: sourcecode.Name
): Unit =
parent.register_module(name.value, nativeModule)

def apply[D <: DType](t: Tensor[D]): Tensor[D] = Tensor(nativeModule.forward(t.native))
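
A hedged usage sketch of the new module; `torch.randn(Seq(...))` is assumed from the existing storch tensor factories:

import torch.*

val logSoftmax = nn.LogSoftmax(dim = 1)
val logits     = torch.randn(Seq(2, 5))   // assumed factory; default dtype Float32
val logProbs   = logSoftmax(logits)       // same shape; exp of each row sums to 1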
42 changes: 42 additions & 0 deletions core/src/main/scala/torch/nn/modules/activation/Tanh.scala
@@ -0,0 +1,42 @@
/*
* Copyright 2022 storch.dev
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package torch
package nn
package modules
package activation

import org.bytedeco.pytorch
import org.bytedeco.pytorch.TanhImpl
import torch.nn.modules.Module
import torch.{DType, Tensor}

/** Applies the Hyperbolic Tangent (Tanh) function element-wise. Tanh is defined as:
 *
 * $\text{Tanh}(x) = \tanh(x) = \frac{\exp(x) - \exp(-x)}{\exp(x) + \exp(-x)}$
*/
final class Tanh[D <: DType: Default]() extends TensorModule[D]:

override protected[torch] val nativeModule: TanhImpl = new TanhImpl()

override def registerWithParent[M <: pytorch.Module](parent: M)(using
name: sourcecode.Name
): Unit =
parent.register_module(name.value, nativeModule)

def apply(t: Tensor[D]): Tensor[D] = Tensor(nativeModule.forward(t.native))

override def toString = getClass().getSimpleName()
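
As with LogSoftmax, a hedged usage sketch, again assuming the `torch.randn(Seq(...))` factory:

import torch.*

val tanh = nn.Tanh[Float32]()
val out  = tanh(torch.randn(Seq(3, 3)))   // values squashed element-wise into (-1, 1)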
135 changes: 135 additions & 0 deletions core/src/main/scala/torch/nn/modules/batchnorm/BatchNorm1d.scala
@@ -0,0 +1,135 @@
/*
* Copyright 2022 storch.dev
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package torch
package nn
package modules
package batchnorm

import org.bytedeco.javacpp.LongPointer
import org.bytedeco.pytorch
import sourcecode.Name
import org.bytedeco.pytorch.BatchNorm1dImpl
import org.bytedeco.pytorch.BatchNormOptions
import torch.nn.modules.{HasParams, HasWeight, TensorModule}

// format: off
/** Applies Batch Normalization over a 2D or 3D input as described in the paper
[Batch Normalization: Accelerating Deep Network Training by Reducing
Internal Covariate Shift](https://arxiv.org/abs/1502.03167).

$$y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta$$

The mean and standard-deviation are calculated per-dimension over the mini-batches, and
$\gamma$ and $\beta$ are learnable parameter vectors of size $C$ (where $C$ is the number of
features or channels of the input). By default, the elements of $\gamma$ are set to 1 and the
elements of $\beta$ are set to 0. The standard-deviation is calculated via the biased
estimator, equivalent to `torch.var(input, unbiased=False)`.

Also by default, during training this layer keeps running estimates of its computed mean and
variance, which are then used for normalization during evaluation. The running estimates are
kept with a default `momentum` of 0.1.

If `track_running_stats` is set to `False`, this layer does not keep running estimates, and
batch statistics are instead used during evaluation as well.

Note: this `momentum` argument is different from the one used in optimizer classes and from
the conventional notion of momentum. Mathematically, the update rule for running statistics
here is
$\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t$,
where $\hat{x}$ is the estimated statistic and $x_t$ is the new observed value.

Because the Batch Normalization is done over the $C$ dimension, computing statistics on
$(N, L)$ slices, it is common terminology to call this Temporal Batch Normalization.

Args:
  - num_features: number of features or channels $C$ of the input
  - eps: a value added to the denominator for numerical stability. Default: 1e-5
  - momentum: the value used for the running_mean and running_var computation. Can be set to
    `None` for cumulative moving average (i.e. simple average). Default: 0.1
  - affine: a boolean value that when set to `True`, this module has learnable affine
    parameters. Default: `True`
  - track_running_stats: a boolean value that when set to `True`, this module tracks the
    running mean and variance, and when set to `False`, this module does not track such
    statistics and initializes the statistics buffers `running_mean` and `running_var` as
    `None`. When these buffers are `None`, this module always uses batch statistics in both
    training and eval modes. Default: `True`

Shape:
  - Input: $(N, C)$ or $(N, C, L)$, where $N$ is the batch size, $C$ is the number of features
    or channels, and $L$ is the sequence length
  - Output: $(N, C)$ or $(N, C, L)$ (same shape as input)

Examples:
  >>> # With Learnable Parameters
  >>> m = nn.BatchNorm1d(100)
  >>> # Without Learnable Parameters
  >>> m = nn.BatchNorm1d(100, affine=False)
  >>> input = torch.randn(20, 100)
  >>> output = m(input)
*
* @group nn_conv
*
* TODO use dtype
*/
// format: on
final class BatchNorm1d[ParamType <: FloatNN | ComplexNN: Default](
numFeatures: Int,
eps: Double = 1e-05,
momentum: Double = 0.1,
affine: Boolean = true,
trackRunningStats: Boolean = true
) extends HasParams[ParamType]
with HasWeight[ParamType]
with TensorModule[ParamType]:

private val options = new BatchNormOptions(numFeatures)
options.eps().put(eps)
options.momentum().put(momentum)
options.affine().put(affine)
options.track_running_stats().put(trackRunningStats)

override private[torch] val nativeModule: BatchNorm1dImpl = BatchNorm1dImpl(options)
nativeModule.asModule.to(paramType.toScalarType)

override def registerWithParent[M <: pytorch.Module](parent: M)(using
name: sourcecode.Name
): Unit =
parent.register_module(name.value, nativeModule)

// TODO weight, bias etc. are undefined if affine = false. We need to take that into account
val weight: Tensor[ParamType] = Tensor[ParamType](nativeModule.weight)
val bias: Tensor[ParamType] = Tensor[ParamType](nativeModule.bias)
// TODO running_mean, running_var, num_batches_tracked

def apply(t: Tensor[ParamType]): Tensor[ParamType] = Tensor(nativeModule.forward(t.native))

override def toString(): String = s"${getClass().getSimpleName()}(numFeatures=$numFeatures)"
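
A hedged Scala counterpart to the Python example in the docstring above; `torch.randn(Seq(...))` is assumed from the existing storch API:

import torch.*

// With learnable parameters
val m = nn.BatchNorm1d[Float32](numFeatures = 100)
// Without learnable parameters
val mNoAffine = nn.BatchNorm1d[Float32](numFeatures = 100, affine = false)
val input  = torch.randn(Seq(20, 100))    // assumed factory
val output = m(input)                     // same shape, normalized per feature over the batch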
core/src/main/scala/torch/nn/modules/batchnorm/BatchNorm2d.scala
@@ -28,7 +28,7 @@ import torch.nn.modules.{HasParams, HasWeight, TensorModule}


// format: off
/** Applies Batch Normalization over a 2D or 3D input as described in the paper
/** Applies Batch Normalization over a 4D input as described in the paper
[Batch Normalization: Accelerating Deep Network Training by Reducing
Internal Covariate Shift](https://arxiv.org/abs/1502.03167) .
61 changes: 61 additions & 0 deletions core/src/main/scala/torch/nn/modules/sparse/Embedding.scala
@@ -0,0 +1,61 @@
/*
* Copyright 2022 storch.dev
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package torch
package nn
package modules
package sparse

import org.bytedeco.javacpp.LongPointer
import org.bytedeco.pytorch
import sourcecode.Name
import org.bytedeco.pytorch.EmbeddingImpl
import org.bytedeco.pytorch.EmbeddingOptions
import torch.nn.modules.{HasParams, HasWeight, TensorModule}
import torch.internal.NativeConverters.{toNative, doubleToDoublePointer}

final class Embedding[ParamType <: FloatNN | ComplexNN: Default](
numEmbeddings: Int,
embeddingDim: Int,
paddingIdx: Option[Int] = None,
maxNorm: Option[Double] = None,
normType: Option[Double] = Some(2.0),
scaleGradByFreq: Boolean = false,
sparse: Boolean = false
) extends HasParams[ParamType]
with HasWeight[ParamType]
with TensorModuleBase[Int64, ParamType]:

private val options = new EmbeddingOptions(numEmbeddings.toLong, embeddingDim.toLong)
paddingIdx.foreach(p => options.padding_idx().put(toNative(p)))
maxNorm.foreach(m => options.max_norm().put(m))
normType.foreach(n => options.norm_type().put(n))
options.scale_grad_by_freq().put(scaleGradByFreq)
options.sparse().put(sparse)

override val nativeModule: EmbeddingImpl = EmbeddingImpl(options)
nativeModule.asModule.to(paramType.toScalarType)

override def registerWithParent[M <: pytorch.Module](parent: M)(using
name: sourcecode.Name
): Unit =
parent.register_module(name.value, nativeModule)

val weight: Tensor[ParamType] = Tensor[ParamType](nativeModule.weight)

def apply(t: Tensor[Int64]): Tensor[ParamType] = Tensor(nativeModule.forward(t.native))

override def toString(): String = s"${getClass().getSimpleName()}(numEmbeddings=$numEmbeddings)"
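
A hedged usage sketch of the new Embedding module; the `torch.Tensor(Seq(...))` factory used to build the index tensor is an assumption about the existing storch API:

import torch.*

val embedding = nn.Embedding[Float32](numEmbeddings = 10, embeddingDim = 3)
val indices   = torch.Tensor(Seq(1L, 2L, 4L, 5L))   // assumed factory; dtype Int64
val vectors   = embedding(indices)                   // shape (4, 3), dtype Float32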
6 changes: 6 additions & 0 deletions core/src/main/scala/torch/nn/package.scala
@@ -31,7 +31,10 @@ package object nn {
export modules.Default

export modules.activation.Softmax
export modules.activation.LogSoftmax
export modules.activation.ReLU
export modules.activation.Tanh
export modules.batchnorm.BatchNorm1d
export modules.batchnorm.BatchNorm2d
export modules.container.Sequential
export modules.conv.Conv2d
@@ -41,4 +44,7 @@ package object nn {
export modules.normalization.GroupNorm
export modules.pooling.AdaptiveAvgPool2d
export modules.pooling.MaxPool2d
export modules.sparse.Embedding

export loss.CrossEntropyLoss
}
