-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add more modules: LogSoftmax, Tanh, BatchNorm1d, Embedding
- Loading branch information
Showing
8 changed files
with
293 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
40 changes: 40 additions & 0 deletions
40
core/src/main/scala/torch/nn/modules/activation/LogSoftmax.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
/* | ||
* Copyright 2022 storch.dev | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package torch | ||
package nn | ||
package modules | ||
package activation | ||
|
||
import org.bytedeco.pytorch | ||
import org.bytedeco.pytorch.LogSoftmaxImpl | ||
import torch.nn.modules.Module | ||
import torch.{DType, Tensor} | ||
|
||
/** Applies the log(Softmax(x)) function to an n-dimensional input Tensor. The LogSoftmax | ||
* formulation can be simplified as: | ||
* | ||
* TODO LaTeX | ||
*/ | ||
final class LogSoftmax(dim: Int) extends Module: | ||
override val nativeModule: LogSoftmaxImpl = LogSoftmaxImpl(dim) | ||
|
||
override def registerWithParent[M <: pytorch.Module](parent: M)(using | ||
name: sourcecode.Name | ||
): Unit = | ||
parent.register_module(name.value, nativeModule) | ||
|
||
def apply[D <: DType](t: Tensor[D]): Tensor[D] = Tensor(nativeModule.forward(t.native)) |
42 changes: 42 additions & 0 deletions
42
core/src/main/scala/torch/nn/modules/activation/Tanh.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
/* | ||
* Copyright 2022 storch.dev | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package torch | ||
package nn | ||
package modules | ||
package activation | ||
|
||
import org.bytedeco.pytorch | ||
import org.bytedeco.pytorch.TanhImpl | ||
import torch.nn.modules.Module | ||
import torch.{DType, Tensor} | ||
|
||
/** Applies the Hyperbolic Tangent (Tanh) function element-wise. Tanh is defined as:: | ||
* | ||
* TODO LaTeX | ||
*/ | ||
final class Tanh[D <: DType: Default]() extends TensorModule[D]: | ||
|
||
override protected[torch] val nativeModule: TanhImpl = new TanhImpl() | ||
|
||
override def registerWithParent[M <: pytorch.Module](parent: M)(using | ||
name: sourcecode.Name | ||
): Unit = | ||
parent.register_module(name.value, nativeModule) | ||
|
||
def apply(t: Tensor[D]): Tensor[D] = Tensor(nativeModule.forward(t.native)) | ||
|
||
override def toString = getClass().getSimpleName() |
135 changes: 135 additions & 0 deletions
135
core/src/main/scala/torch/nn/modules/batchnorm/BatchNorm1d.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,135 @@ | ||
/* | ||
* Copyright 2022 storch.dev | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package torch | ||
package nn | ||
package modules | ||
package batchnorm | ||
|
||
import org.bytedeco.javacpp.LongPointer | ||
import org.bytedeco.pytorch | ||
import sourcecode.Name | ||
import org.bytedeco.pytorch.BatchNorm1dImpl | ||
import org.bytedeco.pytorch.BatchNormOptions | ||
import torch.nn.modules.{HasParams, HasWeight, TensorModule} | ||
|
||
// format: off | ||
/** Applies Batch Normalization over a 2D or 3D input as described in the paper | ||
[Batch Normalization: Accelerating Deep Network Training by Reducing | ||
Internal Covariate Shift](https://arxiv.org/abs/1502.03167) . | ||
$$y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta$$ | ||
The mean and standard-deviation are calculated per-dimension over | ||
the mini-batches and $\gamma$ and $\beta$ are learnable parameter vectors | ||
of size [C]{.title-ref} (where [C]{.title-ref} is the number of features or channels of the input). By default, the | ||
elements of $\gamma$ are set to 1 and the elements of $\beta$ are set to 0. The | ||
standard-deviation is calculated via the biased estimator, equivalent to [torch.var(input, unbiased=False)]{.title-ref}. | ||
Also by default, during training this layer keeps running estimates of its | ||
computed mean and variance, which are then used for normalization during | ||
evaluation. The running estimates are kept with a default `momentum`{.interpreted-text role="attr"} | ||
of 0.1. | ||
If `track_running_stats`{.interpreted-text role="attr"} is set to `False`, this layer then does not | ||
keep running estimates, and batch statistics are instead used during | ||
evaluation time as well. | ||
::: note | ||
::: title | ||
Note | ||
::: | ||
This `momentum`{.interpreted-text role="attr"} argument is different from one used in optimizer | ||
classes and the conventional notion of momentum. Mathematically, the | ||
update rule for running statistics here is | ||
$\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t$, | ||
where $\hat{x}$ is the estimated statistic and $x_t$ is the | ||
new observed value. | ||
::: | ||
Because the Batch Normalization is done over the [C]{.title-ref} dimension, computing statistics | ||
on [(N, L)]{.title-ref} slices, it\'s common terminology to call this Temporal Batch Normalization. | ||
Args: | ||
: num_features: number of features or channels $C$ of the input | ||
eps: a value added to the denominator for numerical stability. | ||
Default: 1e-5 | ||
momentum: the value used for the running_mean and running_var | ||
computation. Can be set to `None` for cumulative moving average | ||
(i.e. simple average). Default: 0.1 | ||
affine: a boolean value that when set to `True`, this module has | ||
learnable affine parameters. Default: `True` | ||
track_running_stats: a boolean value that when set to `True`, this | ||
module tracks the running mean and variance, and when set to `False`, | ||
this module does not track such statistics, and initializes statistics | ||
buffers `running_mean`{.interpreted-text role="attr"} and `running_var`{.interpreted-text role="attr"} as `None`. | ||
When these buffers are `None`, this module always uses batch statistics. | ||
in both training and eval modes. Default: `True` | ||
Shape: | ||
: - Input: $(N, C)$ or $(N, C, L)$, where $N$ is the batch size, | ||
$C$ is the number of features or channels, and $L$ is the sequence length | ||
- Output: $(N, C)$ or $(N, C, L)$ (same shape as input) | ||
Examples: | ||
>>> # With Learnable Parameters | ||
>>> m = nn.BatchNorm1d(100) | ||
>>> # Without Learnable Parameters | ||
>>> m = nn.BatchNorm1d(100, affine=False) | ||
>>> input = torch.randn(20, 100) | ||
>>> output = m(input) | ||
* | ||
* @group nn_conv | ||
* | ||
* TODO use dtype | ||
*/ | ||
// format: on | ||
final class BatchNorm1d[ParamType <: FloatNN | ComplexNN: Default]( | ||
numFeatures: Int, | ||
eps: Double = 1e-05, | ||
momentum: Double = 0.1, | ||
affine: Boolean = true, | ||
trackRunningStats: Boolean = true | ||
) extends HasParams[ParamType] | ||
with HasWeight[ParamType] | ||
with TensorModule[ParamType]: | ||
|
||
private val options = new BatchNormOptions(numFeatures) | ||
options.eps().put(eps) | ||
options.momentum().put(momentum) | ||
options.affine().put(affine) | ||
options.track_running_stats().put(trackRunningStats) | ||
|
||
override private[torch] val nativeModule: BatchNorm1dImpl = BatchNorm1dImpl(options) | ||
nativeModule.asModule.to(paramType.toScalarType) | ||
|
||
override def registerWithParent[M <: pytorch.Module](parent: M)(using | ||
name: sourcecode.Name | ||
): Unit = | ||
parent.register_module(name.value, nativeModule) | ||
|
||
// TODO weight, bias etc. are undefined if affine = false. We need to take that into account | ||
val weight: Tensor[ParamType] = Tensor[ParamType](nativeModule.weight) | ||
val bias: Tensor[ParamType] = Tensor[ParamType](nativeModule.bias) | ||
// TODO running_mean, running_var, num_batches_tracked | ||
|
||
def apply(t: Tensor[ParamType]): Tensor[ParamType] = Tensor(nativeModule.forward(t.native)) | ||
|
||
override def toString(): String = s"${getClass().getSimpleName()}(numFeatures=$numFeatures)" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
61 changes: 61 additions & 0 deletions
61
core/src/main/scala/torch/nn/modules/sparse/Embedding.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
/* | ||
* Copyright 2022 storch.dev | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package torch | ||
package nn | ||
package modules | ||
package sparse | ||
|
||
import org.bytedeco.javacpp.LongPointer | ||
import org.bytedeco.pytorch | ||
import sourcecode.Name | ||
import org.bytedeco.pytorch.EmbeddingImpl | ||
import org.bytedeco.pytorch.EmbeddingOptions | ||
import torch.nn.modules.{HasParams, HasWeight, TensorModule} | ||
import torch.internal.NativeConverters.{toNative, doubleToDoublePointer} | ||
|
||
final class Embedding[ParamType <: FloatNN | ComplexNN: Default]( | ||
numEmbeddings: Int, | ||
embeddingDim: Int, | ||
paddingIdx: Option[Int] = None, | ||
maxNorm: Option[Double] = None, | ||
normType: Option[Double] = Some(2.0), | ||
scaleGradByFreq: Boolean = false, | ||
sparse: Boolean = false | ||
) extends HasParams[ParamType] | ||
with HasWeight[ParamType] | ||
with TensorModuleBase[Int64, ParamType]: | ||
|
||
private val options = new EmbeddingOptions(numEmbeddings.toLong, embeddingDim.toLong) | ||
paddingIdx.foreach(p => options.padding_idx().put(toNative(p))) | ||
maxNorm.foreach(m => options.max_norm().put(m)) | ||
normType.foreach(n => options.norm_type().put(n)) | ||
options.scale_grad_by_freq().put(scaleGradByFreq) | ||
options.sparse().put(sparse) | ||
|
||
override val nativeModule: EmbeddingImpl = EmbeddingImpl(options) | ||
nativeModule.asModule.to(paramType.toScalarType) | ||
|
||
override def registerWithParent[M <: pytorch.Module](parent: M)(using | ||
name: sourcecode.Name | ||
): Unit = | ||
parent.register_module(name.value, nativeModule) | ||
|
||
val weight: Tensor[ParamType] = Tensor[ParamType](nativeModule.weight) | ||
|
||
def apply(t: Tensor[Int64]): Tensor[ParamType] = Tensor(nativeModule.forward(t.native)) | ||
|
||
override def toString(): String = s"${getClass().getSimpleName()}(numEmbeddings=$numEmbeddings)" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters