From 2935def607b6c579761112a9625560b7102ef545 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=B6ren=20Brunk?= Date: Sun, 21 May 2023 22:45:07 +0200 Subject: [PATCH 1/5] Add more ops and todos for missing ops --- core/src/main/scala/torch/torch.scala | 167 ++++++++++++++++++++++++++ 1 file changed, 167 insertions(+) diff --git a/core/src/main/scala/torch/torch.scala b/core/src/main/scala/torch/torch.scala index 97c61516..3a2a28ab 100644 --- a/core/src/main/scala/torch/torch.scala +++ b/core/src/main/scala/torch/torch.scala @@ -480,6 +480,34 @@ def cat[D <: DType](tensors: Seq[Tensor[D]], dim: Int = 0): Tensor[D] = Tensor( torchNative.cat(new TensorArrayRef(new TensorVector(tensors.map(_.native)*)), dim.toLong) ) +// TODO dsplit +// TODO column_stack +// TODO dstack +// TODO gather +// TODO hsplit +// TODO hstack +// TODO index_add +// TODO index_copy +// TODO index_reduce +// TODO index_select +// TODO masked_select +// TODO movedim +// TODO moveaxis +// TODO narrow +// TODO narrow_copy +// TODO nonzero +// TODO permute +// TODO reshape +// TODO select +// TODO scatter +// TODO diagonal_scatter +// TODO select_scatter +// TODO slice_scatter +// TODO scatter_add +// TODO scatter_reduce +// TODO split +// TODO squeeze + /** Concatenates a sequence of tensors along a new dimension. * * All tensors need to be of the same size. @@ -488,10 +516,149 @@ def stack[D <: DType](tensors: Seq[Tensor[D]], dim: Int = 0): Tensor[D] = Tensor torchNative.stack(new TensorArrayRef(new TensorVector(tensors.map(_.native)*)), dim) ) +// TODO swapaxes +// TODO swapdims +// TODO t +// TODO take +// TODO take_along_dim +// TODO tensor_split +// TODO tile +// TODO transpose +// TODO unbind +// TODO unsqueeze +// TODO vsplit +// TODO vstack +// TODO where + // End Indexing, Slicing, Joining, Mutating Ops // Math operations +// Pointwise Ops + +/** Computes the absolute value of each element in `input`. */ +def abs[D <: DType](input: Tensor[D]) = Tensor(torchNative.abs(input.native)) + +/** Computes the inverse cosine of each element in `input`. */ +def acos[D <: DType](input: Tensor[D]) = Tensor(torchNative.acos(input.native)) + +/** Returns a new tensor with the inverse hyperbolic cosine of the elements of `input` . */ +def acosh[D <: DType](input: Tensor[D]) = Tensor(torchNative.acosh(input.native)) + +/** Adds `other` to `input`. */ +def add[D <: DType, D2 <: DType](input: Tensor[D], other: Tensor[D2]): Tensor[Promoted[D, D2]] = + Tensor(torchNative.add(input.native, other.native)) + +/** Adds `other` to `input`. */ +def add[D <: DType, S <: ScalaType]( + input: Tensor[D], + other: ScalaType +): Tensor[Promoted[D, ScalaToDType[S]]] = + Tensor(torchNative.add(input.native, toScalar(other))) + +// TODO addcdiv +// TODO addcmul +// TODO angle +// TODO asin +// TODO atan +// TODO atanh +// TODO atan2 +// TODO bitwise_not +// TODO bitwise_and +// TODO bitwise_or +// TODO bitwise_xor +// TODO bitwise_left_shift +// TODO bitwise_right_shift +// TODO ceil +// TODO clamp +// TODO conj_physical +// TODO copysign +// TODO cos +// TODO cosh +// TODO deg2rad +// TODO div +// TODO exp +// TODO exp2 +// TODO expm1 +// TODO fake_quantize_per_channel_affine +// TODO fake_quantize_per_tensor_affine +// TODO float_power +// TODO floor +// TODO floor_divide +// TODO fmod +// TODO frac +// TODO frexp +// TODO gradient +// TODO imag +// TODO ldexp +// TODO lerp +// TODO lgamma +// TODO log + +/** Returns a new tensor with the natural logarithm of the elements of input. 
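+  *
+  * An illustrative example (values approximate):
+  * {{{
+  *   torch.log(Tensor(Seq(1.0, math.E))) // ≈ Tensor(Seq(0.0, 1.0))
+  * }}}
+  *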
*/ +def log[D <: DType](input: Tensor[D]) = Tensor(torchNative.log(input.native)) + +// TODO log10 +// TODO log1p + +/** Returns a new tensor with the logarithm to the base 2 of the elements of `input`. */ +def log2[D <: DType](input: Tensor[D]) = Tensor(torchNative.log2(input.native)) + +// TODO logaddexp +// TODO logaddexp2 +// TODO logical_and +// TODO locigal_not +// TODO logical_or +// TODO logical_xor +// TODO logit +// TODO hypot +// TODO i0 +// TODO igamma +// TODO igammac + +def mul[D <: DType, D2 <: DType](input: Tensor[D], other: Tensor[D2]): Tensor[Promoted[D, D2]] = + Tensor(torchNative.mul(input.native, other.native)) + +// TODO mlvgamma +// TODO nan_to_num +// TODO neg +// TODO nextafter +// TODO polygamma +// TODO positive +// TODO pow +// TODO quantized_batch_norm +// TODO quantized_max_pool1d +// TODO quantized_max_pool2d +// TODO rad2deg +// TODO real +// TODO reciprocal +// TODO remainder +// TODO round +// TODO rsqrt +// TODO sigmoid +// TODO sign +// TODO sgn +// TODO signbit + +def sin[D <: DType](input: Tensor[D]) = Tensor(torchNative.sin(input.native)) + +// TODO sinc +// TODO sinh +// TODO softmax + +export torch.nn.functional.softmax + +// TODO sqrt +// TODO square +// TODO sub +// TODO tan +// TODO tanh +// TODO true_divide +// TODO trunc +// TODO xlogy + +// End Pointwise Ops + // Comparison Ops def allclose( From 8ee8b81b0ab14fc9760aabea984d41aae6278f42 Mon Sep 17 00:00:00 2001 From: David Gomez-Urquiza Date: Sun, 21 May 2023 15:41:13 -0600 Subject: [PATCH 2/5] Add contributing guide, reproducible build files, and linting git hook - CONTRIBUTING guide - Add pre-push git hooks - Devenv required files --- .envrc | 4 ++ .gitignore | 5 ++ CONTRIBUTING.md | 53 +++++++++++++++ devenv.lock | 138 ++++++++++++++++++++++++++++++++++++++ devenv.nix | 18 +++++ devenv.yaml | 4 ++ git-hooks/pre-push-checks | 7 ++ 7 files changed, 229 insertions(+) create mode 100644 .envrc create mode 100644 CONTRIBUTING.md create mode 100644 devenv.lock create mode 100644 devenv.nix create mode 100644 devenv.yaml create mode 100755 git-hooks/pre-push-checks diff --git a/.envrc b/.envrc new file mode 100644 index 00000000..5dae8de0 --- /dev/null +++ b/.envrc @@ -0,0 +1,4 @@ +watch_file devenv.nix +watch_file devenv.yaml +watch_file devenv.lock +eval "$(devenv print-dev-env)" \ No newline at end of file diff --git a/.gitignore b/.gitignore index f38f1b44..c1a44ddf 100644 --- a/.gitignore +++ b/.gitignore @@ -9,3 +9,8 @@ metals.sbt *.worksheet.sc /data/ .scala-build/ + +# Devenv +.devenv* +devenv.local.nix + diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 00000000..0252689a --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,53 @@ +# Contributing to Storch + + +## Install dependencies via nix + devenv + +1. Install [nix](https://nixos.org) package manager + +```bash +sh <(curl -L https://nixos.org/nix/install) --daemon +``` + +For more info, see https://nixos.org/download.html + +2. Install [devenv](https://devenv.sh) + +```bash +nix profile install --accept-flake-config github:cachix/devenv/latest +``` + +For more info, see: https://devenv.sh/getting-started/#installation + +3. [Optionally] Install [direnv](https://direnv.net/) + +This will load the specific environment variables upon `cd` into the storch folder + +```bash +nix profile install 'nixpkgs#direnv' +``` + +4. 
Load environment
+
+If you did not install direnv, run the following in the `storch` root folder:
+
+```bash
+devenv shell
+```
+
+If you installed direnv, just `cd` into the storch folder.
+
+
+## Linting
+
+### Manually run headerCreate + scalafmt on all files
+
+```bash
+sbt 'headerCreateAll ; scalafmtAll'
+```
+
+### Add useful git pre-push linting checks
+
+```bash
+cp git-hooks/pre-push-checks .git/hooks/pre-push && chmod +x .git/hooks/pre-push
+```
diff --git a/devenv.lock b/devenv.lock
new file mode 100644
index 00000000..3c8bafe6
--- /dev/null
+++ b/devenv.lock
@@ -0,0 +1,138 @@
+{
+  "nodes": {
+    "devenv": {
+      "locked": {
+        "dir": "src/modules",
+        "lastModified": 1677225427,
+        "narHash": "sha256-+M4LGzGVAhkM0HOy2lexDUBwlLnv8UUX32utlSFaMc4=",
+        "owner": "cachix",
+        "repo": "devenv",
+        "rev": "1aa3dbbf745f50e77aadc016d124fed3ca2dc9be",
+        "type": "github"
+      },
+      "original": {
+        "dir": "src/modules",
+        "owner": "cachix",
+        "repo": "devenv",
+        "type": "github"
+      }
+    },
+    "flake-compat": {
+      "flake": false,
+      "locked": {
+        "lastModified": 1673956053,
+        "narHash": "sha256-4gtG9iQuiKITOjNQQeQIpoIB6b16fm+504Ch3sNKLd8=",
+        "owner": "edolstra",
+        "repo": "flake-compat",
+        "rev": "35bb57c0c8d8b62bbfd284272c928ceb64ddbde9",
+        "type": "github"
+      },
+      "original": {
+        "owner": "edolstra",
+        "repo": "flake-compat",
+        "type": "github"
+      }
+    },
+    "flake-utils": {
+      "locked": {
+        "lastModified": 1667395993,
+        "narHash": "sha256-nuEHfE/LcWyuSWnS8t12N1wc105Qtau+/OdUAjtQ0rA=",
+        "owner": "numtide",
+        "repo": "flake-utils",
+        "rev": "5aed5285a952e0b949eb3ba02c12fa4fcfef535f",
+        "type": "github"
+      },
+      "original": {
+        "owner": "numtide",
+        "repo": "flake-utils",
+        "type": "github"
+      }
+    },
+    "gitignore": {
+      "inputs": {
+        "nixpkgs": [
+          "pre-commit-hooks",
+          "nixpkgs"
+        ]
+      },
+      "locked": {
+        "lastModified": 1660459072,
+        "narHash": "sha256-8DFJjXG8zqoONA1vXtgeKXy68KdJL5UaXR8NtVMUbx8=",
+        "owner": "hercules-ci",
+        "repo": "gitignore.nix",
+        "rev": "a20de23b925fd8264fd7fad6454652e142fd7f73",
+        "type": "github"
+      },
+      "original": {
+        "owner": "hercules-ci",
+        "repo": "gitignore.nix",
+        "type": "github"
+      }
+    },
+    "nixpkgs": {
+      "locked": {
+        "lastModified": 1677352614,
+        "narHash": "sha256-VYo1cSiCHDXZrHO8pb0c9EGob7C75lCPx1jBMi9UAlU=",
+        "owner": "NixOS",
+        "repo": "nixpkgs",
+        "rev": "bf592ea571b11dfee17a74d022f0b481ca5f1319",
+        "type": "github"
+      },
+      "original": {
+        "owner": "NixOS",
+        "ref": "nixpkgs-unstable",
+        "repo": "nixpkgs",
+        "type": "github"
+      }
+    },
+    "nixpkgs-stable": {
+      "locked": {
+        "lastModified": 1673800717,
+        "narHash": "sha256-SFHraUqLSu5cC6IxTprex/nTsI81ZQAtDvlBvGDWfnA=",
+        "owner": "NixOS",
+        "repo": "nixpkgs",
+        "rev": "2f9fd351ec37f5d479556cd48be4ca340da59b8f",
+        "type": "github"
+      },
+      "original": {
+        "owner": "NixOS",
+        "ref": "nixos-22.11",
+        "repo": "nixpkgs",
+        "type": "github"
+      }
+    },
+    "pre-commit-hooks": {
+      "inputs": {
+        "flake-compat": "flake-compat",
+        "flake-utils": "flake-utils",
+        "gitignore": "gitignore",
+        "nixpkgs": [
+          "nixpkgs"
+        ],
+        "nixpkgs-stable": "nixpkgs-stable"
+      },
+      "locked": {
+        "lastModified": 1677160285,
+        "narHash": "sha256-tBzpCjMP+P3Y3nKLYvdBkXBg3KvTMo3gvi8tLQaqXVY=",
+        "owner": "cachix",
+        "repo": "pre-commit-hooks.nix",
+        "rev": "2bd861ab81469428d9c823ef72c4bb08372dd2c4",
+        "type": "github"
+      },
+      "original": {
+        "owner": "cachix",
+        "repo": "pre-commit-hooks.nix",
+        "type": "github"
+      }
+    },
+    "root": {
+      "inputs": {
+        "devenv": "devenv",
+        "nixpkgs": "nixpkgs",
+        "pre-commit-hooks": "pre-commit-hooks"
+      }
+    }
+  },
+  "root": "root",
+  "version": 7
+}
diff
--git a/devenv.nix b/devenv.nix new file mode 100644 index 00000000..4a4454bc --- /dev/null +++ b/devenv.nix @@ -0,0 +1,18 @@ +{ pkgs, inputs, ... }: + +let + packages = if pkgs.stdenv.isDarwin + then inputs.nixpkgs.legacyPackages.x86_64-darwin + else pkgs; +in +{ + packages = with packages; [ + sbt + ]; + + scripts.hello.exec = "echo ---STORCH---"; + + enterShell = '' + hello + ''; +} diff --git a/devenv.yaml b/devenv.yaml new file mode 100644 index 00000000..8de00394 --- /dev/null +++ b/devenv.yaml @@ -0,0 +1,4 @@ +inputs: + nixpkgs: + url: github:NixOS/nixpkgs/nixpkgs-unstable +# can't point to the local modules here as it's used as a template \ No newline at end of file diff --git a/git-hooks/pre-push-checks b/git-hooks/pre-push-checks new file mode 100755 index 00000000..e20a1d6b --- /dev/null +++ b/git-hooks/pre-push-checks @@ -0,0 +1,7 @@ +#!/bin/sh + +# Go to the repository root +cd "${GIT_DIR}/.." + +# Run the sbt linting checks +sbt 'headerCheckAll ; scalafmtCheckAll ; scalafmtSbtCheck' From 2a562d2b1b67e35404830aaad306a57f57a723d7 Mon Sep 17 00:00:00 2001 From: David Gomez-Urquiza Date: Tue, 23 May 2023 23:45:11 -0600 Subject: [PATCH 3/5] Implement torch Pointwise Ops, add property+unit testing helper - Implement torch Pointwise Ops - https://pytorch.org/docs/stable/torch.html#pointwise-ops - Add test helper to perform property + unit tests - Add useful union type declarations --- core/src/main/scala/torch/DType.scala | 48 +- .../torch/internal/NativeConverters.scala | 17 +- .../main/scala/torch/special/package.scala | 98 +++ core/src/main/scala/torch/torch.scala | 632 +++++++++++++++--- core/src/test/scala/torch/Generators.scala | 13 +- core/src/test/scala/torch/TensorSuite.scala | 446 ++++++++++++ 6 files changed, 1148 insertions(+), 106 deletions(-) create mode 100644 core/src/main/scala/torch/special/package.scala diff --git a/core/src/main/scala/torch/DType.scala b/core/src/main/scala/torch/DType.scala index 8a7b664b..22bad59e 100644 --- a/core/src/main/scala/torch/DType.scala +++ b/core/src/main/scala/torch/DType.scala @@ -215,8 +215,17 @@ type IntNN = Int8 | UInt8 | Int16 | Int32 | Int64 type ComplexNN = Complex32 | Complex64 | Complex128 -type ScalaType = Boolean | Byte | UByte | Short | Int | Long | Float | Double | Complex[Float] | - Complex[Double] +type BitwiseNN = Bool | IntNN + +type NumericRealNN = IntNN | FloatNN + +type RealNN = NumericRealNN | Bool + +type NumericNN = NumericRealNN | ComplexNN + +type Real = Boolean | Byte | UByte | Short | Int | Long | Float | Double + +type ScalaType = Real | Complex[Float] | Complex[Double] type DTypeToScala[T <: DType] <: ScalaType = T match case UInt8 => UByte @@ -289,6 +298,8 @@ type DTypeOrDeriveFromTensor[D1 <: DType, U <: DType | Derive] <: DType = U matc case Derive => D1 case U => TensorType[U] +type TensorOrReal[D <: DType] = Tensor[D] | Real + type PromotedDType[A <: DType, B <: DType] <: Float32 | Int32 | Int64 = (A, B) match case (Float64, B) => Float32 case (A, Float64) => Float32 @@ -366,15 +377,40 @@ type Promoted[T <: DType, U <: DType] <: DType = (T, U) match case (T, Complex128) => T case _ => DType +/** Promoted type for tensor operations that always output numbers (e.g. `square`) */ +type NumericPromoted[D <: DType] <: DType = D match + case Bool => Int64 + case _ => D + +/** Promoted type for tensor operations that always output floats (e.g. 
`sin`) */ +type FloatPromoted[D <: DType] <: FloatNN = D match + case Float64 => Float64 + case _ => Float32 + +/** Demoted type for complex to real type extractions (e.g. `imag`, `real`) */ +type ComplexToReal[D <: DType] <: DType = D match + case Complex32 => Float16 + case Complex64 => Float32 + case Complex128 => Float64 + case _ => D + +/** Promoted type for tensor operations that always output full sized complex or real (e.g. + * `floatPower`) + */ +type ComplexPromoted[T <: DType, U <: DType] <: Float64 | Complex128 = (T, U) match + case (ComplexNN, U) => Complex128 + case (T, ComplexNN) => Complex128 + case _ => Float64 + /** Promoted type for tensor division */ type Div[T <: DType, U <: DType] <: DType = (T, U) match - case (Bool | IntNN, Bool | IntNN) => Float32 - case _ => Promoted[T, U] + case (BitwiseNN, BitwiseNN) => Float32 + case _ => Promoted[T, U] /** Promoted type for elementwise tensor sum */ type Sum[D <: DType] <: DType = D match - case Bool | IntNN => Int64 - case D => D + case BitwiseNN => Int64 + case D => D private[torch] type TypedBuffer[T <: ScalaType] <: Buffer = T match case Short => ShortBuffer diff --git a/core/src/main/scala/torch/internal/NativeConverters.scala b/core/src/main/scala/torch/internal/NativeConverters.scala index fe0de940..0c7e6f42 100644 --- a/core/src/main/scala/torch/internal/NativeConverters.scala +++ b/core/src/main/scala/torch/internal/NativeConverters.scala @@ -18,12 +18,15 @@ package torch package internal import org.bytedeco.pytorch -import org.bytedeco.pytorch.ScalarTypeOptional -import org.bytedeco.pytorch.LayoutOptional -import org.bytedeco.pytorch.DeviceOptional -import org.bytedeco.pytorch.BoolOptional -import org.bytedeco.pytorch.LongOptional -import org.bytedeco.pytorch.TensorOptional +import org.bytedeco.pytorch.{ + ScalarTypeOptional, + LayoutOptional, + DeviceOptional, + DoubleOptional, + BoolOptional, + LongOptional, + TensorOptional +} import scala.reflect.Typeable import org.bytedeco.javacpp.LongPointer @@ -39,6 +42,8 @@ private[torch] object NativeConverters: case i: T => f(i) def toOptional(l: Long | Option[Long]): LongOptional = toOptional(l, pytorch.LongOptional(_)) + def toOptional(l: Double | Option[Double]): DoubleOptional = + toOptional(l, pytorch.DoubleOptional(_)) def toOptional[D <: DType](t: Tensor[D] | Option[Tensor[D]]): TensorOptional = toOptional(t, t => pytorch.TensorOptional(t.native)) diff --git a/core/src/main/scala/torch/special/package.scala b/core/src/main/scala/torch/special/package.scala new file mode 100644 index 00000000..588ebe0e --- /dev/null +++ b/core/src/main/scala/torch/special/package.scala @@ -0,0 +1,98 @@ +/* + * Copyright 2022 storch.dev + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package torch + +import org.bytedeco.pytorch.global.torch as torchNative + +import internal.NativeConverters.* + +package object special: + /** Computes the logarithmic derivative of the gamma function on `input`. 
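+    *
+    * An illustrative example (values approximate, matching the unit tests):
+    * {{{
+    *   torch.special.digamma(Tensor(Seq(1.0, 0.5))) // ≈ Tensor(Seq(-0.5772, -1.9635))
+    * }}}
+    *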
*/
+  def digamma[D <: RealNN](input: Tensor[D]): Tensor[FloatPromoted[D]] =
+    Tensor(torchNative.digamma(input.native))
+
+  /** Computes the error function of `input`. */
+  def erf[D <: RealNN](input: Tensor[D]): Tensor[FloatPromoted[D]] =
+    Tensor(torchNative.erf(input.native))
+
+  /** Computes the complementary error function of `input`. */
+  def erfc[D <: RealNN](input: Tensor[D]): Tensor[FloatPromoted[D]] =
+    Tensor(torchNative.erfc(input.native))
+
+  /** Computes the inverse error function of `input`. The inverse error function is defined in the
+    * range (−1, 1).
+    */
+  def erfinv[D <: RealNN](input: Tensor[D]): Tensor[FloatPromoted[D]] =
+    Tensor(torchNative.erfinv(input.native))
+
+  /** Computes the base two exponential function of `input`. */
+  def exp2[D <: RealNN](input: Tensor[D]): Tensor[FloatPromoted[D]] =
+    Tensor(torchNative.exp2(input.native))
+
+  /** Computes the exponential of the elements of `input` minus 1. */
+  def expm1[D <: RealNN](input: Tensor[D]): Tensor[FloatPromoted[D]] =
+    Tensor(torchNative.expm1(input.native))
+
+  /** Computes the zeroth order modified Bessel function of the first kind for each element of
+    * `input`.
+    */
+  def i0[D <: RealNN](input: Tensor[D]): Tensor[FloatPromoted[D]] =
+    Tensor(torchNative.i0(input.native))
+
+  /** Computes the regularized lower incomplete gamma function */
+  // NOTE it is named `gammainc` in pytorch torch.special
+  def igamma[D <: DType, D2 <: DType](
+      input: Tensor[D],
+      other: Tensor[D2]
+  ): Tensor[FloatPromoted[Promoted[D, D2]]] =
+    Tensor(torchNative.igamma(input.native, other.native))
+
+  /** Computes the regularized upper incomplete gamma function */
+  // NOTE it is named `gammaincc` in pytorch torch.special
+  def igammac[D <: DType, D2 <: DType](
+      input: Tensor[D],
+      other: Tensor[D2]
+  ): Tensor[FloatPromoted[Promoted[D, D2]]] =
+    Tensor(torchNative.igammac(input.native, other.native))
+
+  /** Returns a new tensor with the logit of the elements of `input`. `input` is clamped to
+    * [eps, 1 - eps] when `eps` is defined. When `eps` is `None` and `input` < 0 or `input` > 1,
+    * the function yields NaN.
+    */
+  def logit[D <: DType](input: Tensor[D], eps: Option[Double] = None): Tensor[FloatPromoted[D]] =
+    Tensor(torchNative.logit(input.native, toOptional(eps)))
+
+  /** Computes the multivariate log-gamma function with dimension p element-wise */
+  // NOTE it is named `multigammaln` in pytorch torch.special
+  def mvlgamma[D <: DType](input: Tensor[D], p: Int): Tensor[FloatPromoted[D]] =
+    Tensor(torchNative.mvlgamma(input.native, p))
+
+  /** Computes the nth derivative of the digamma function on `input`. n ≥ 0 is called the order of
+    * the polygamma function.
+    */
+  def polygamma[D <: DType](n: Int, input: Tensor[D]): Tensor[FloatPromoted[D]] =
+    Tensor(torchNative.polygamma(n, input.native))
+
+  /** Computes the expit (also known as the logistic sigmoid function) of the elements of `input`.
+    */
+  // NOTE it is named `expit` in pytorch torch.special
+  def sigmoid[D <: DType](input: Tensor[D]): Tensor[FloatPromoted[D]] =
+    Tensor(torchNative.sigmoid(input.native))
+
+  /** Returns a new tensor with the normalized sinc of the elements of `input`.
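+    *
+    * Here the normalized sinc is sinc(x) = sin(π · x) / (π · x) for x ≠ 0, and sinc(0) = 1.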
*/ + def sinc[D <: RealNN](input: Tensor[D]): Tensor[FloatPromoted[D]] = + Tensor(torchNative.sinc(input.native)) diff --git a/core/src/main/scala/torch/torch.scala b/core/src/main/scala/torch/torch.scala index 3a2a28ab..5ca7f293 100644 --- a/core/src/main/scala/torch/torch.scala +++ b/core/src/main/scala/torch/torch.scala @@ -67,7 +67,6 @@ import scala.util.Using // // TODO as_tensor // // TODO as_strided // // TODO frombuffer -// def zeros(size: Int*): Tensor[Float32] = zeros(size.map(_.toLong)) /** Returns a tensor filled with the scalar value `0`, with the shape defined by the variable * argument `size`. @@ -76,6 +75,8 @@ import scala.util.Using * @tparam T * @return */ +// def zeros[D <: DType](size: Int*): Tensor[Float32] = +// zeros[D](size.toSeq) def zeros[D <: DType]( size: Seq[Int] | Int, dtype: D = float32, @@ -88,7 +89,7 @@ def zeros[D <: DType]( case s: Int => Array(s.toLong) Tensor( torchNative.torch_zeros( - nativeSize.toArray.map(_.toLong), + nativeSize, NativeConverters.tensorOptions(dtype, layout, device, requiresGrad) ) ) @@ -389,11 +390,11 @@ def rand[D <: FloatNN | ComplexNN]( * @param input * the size of `input` will determine size of the output tensor. * @param dtype - * the desired data type of returned Tensor. If `derive`, defaults to the dtype of input. + * the desired data type of returned Tensor. If `derive`, defaults to the dtype of `input`. * @param layout - * the desired layout of returned tensor. If `derive`, defaults to the layout of input. + * the desired layout of returned tensor. If `derive`, defaults to the layout of `input`. * @param device - * the desired device of returned tensor. If `derive` , defaults to the device of input. + * the desired device of returned tensor. If `derive` , defaults to the device of `input`. * @param requiresGrad * If autograd should record operations on the returned tensor. * @param memoryFormat @@ -552,109 +553,558 @@ def add[D <: DType, D2 <: DType](input: Tensor[D], other: Tensor[D2]): Tensor[Pr /** Adds `other` to `input`. */ def add[D <: DType, S <: ScalaType]( input: Tensor[D], - other: ScalaType + other: S ): Tensor[Promoted[D, ScalaToDType[S]]] = Tensor(torchNative.add(input.native, toScalar(other))) -// TODO addcdiv -// TODO addcmul -// TODO angle -// TODO asin -// TODO atan -// TODO atanh -// TODO atan2 -// TODO bitwise_not -// TODO bitwise_and -// TODO bitwise_or -// TODO bitwise_xor -// TODO bitwise_left_shift -// TODO bitwise_right_shift -// TODO ceil -// TODO clamp -// TODO conj_physical -// TODO copysign -// TODO cos -// TODO cosh -// TODO deg2rad -// TODO div -// TODO exp -// TODO exp2 -// TODO expm1 -// TODO fake_quantize_per_channel_affine -// TODO fake_quantize_per_tensor_affine -// TODO float_power -// TODO floor -// TODO floor_divide -// TODO fmod -// TODO frac -// TODO frexp -// TODO gradient -// TODO imag -// TODO ldexp -// TODO lerp -// TODO lgamma -// TODO log - -/** Returns a new tensor with the natural logarithm of the elements of input. */ -def log[D <: DType](input: Tensor[D]) = Tensor(torchNative.log(input.native)) - -// TODO log10 -// TODO log1p +/** Performs the element-wise division of tensor1 by tensor2, multiplies the result by the scalar + * value and adds it to input. 
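+  *
+  * In symbols (a sketch of the semantics): `out_i = input_i + value * (tensor1_i / tensor2_i)`.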
+ */ +def addcdiv[D <: DType, D2 <: DType, D3 <: DType]( + input: Tensor[D], + tensor1: Tensor[D2], + tensor2: Tensor[D3], + value: ScalaType +): Tensor[Promoted[D, Promoted[D2, D3]]] = + Tensor(torchNative.addcdiv(input.native, tensor1.native, tensor2.native, toScalar(value))) + +/** Performs the element-wise multiplication of tensor1 by tensor2, multiplies the result by the + * scalar value and adds it to input. + */ +def addcmul[D <: DType, D2 <: DType, D3 <: DType]( + input: Tensor[D], + tensor1: Tensor[D2], + tensor2: Tensor[D3], + value: ScalaType +): Tensor[Promoted[D, Promoted[D2, D3]]] = + Tensor(torchNative.addcmul(input.native, tensor1.native, tensor2.native, toScalar(value))) + +/** Computes the element-wise angle (in radians) of the given `input` tensor. */ +def angle[D <: DType](input: Tensor[D]): Tensor[ComplexToReal[D]] = + Tensor(torchNative.angle(input.native)) + +/** Returns a new tensor with the arcsine of the elements of `input`. */ +def asin[D <: RealNN](input: Tensor[D]): Tensor[FloatPromoted[D]] = + Tensor(torchNative.asin(input.native)) + +/** Returns a new tensor with the inverse hyperbolic sine of the elements of `input`. */ +def asinh[D <: RealNN](input: Tensor[D]): Tensor[FloatPromoted[D]] = + Tensor(torchNative.asinh(input.native)) + +/** Returns a new tensor with the arctangent of the elements of `input`. */ +def atan[D <: RealNN](input: Tensor[D]): Tensor[FloatPromoted[D]] = + Tensor(torchNative.atan(input.native)) + +/** Returns a new tensor with the inverse hyperbolic tangent of the elements of `input`. */ +def atanh[D <: RealNN](input: Tensor[D]): Tensor[FloatPromoted[D]] = + Tensor(torchNative.atanh(input.native)) + +/** Element-wise arctangent of (input / other) with consideration of the quadrant. Returns a new + * tensor with the signed angles in radians between vector (other, input) and vector (1, 0). (Note + * that other, the second parameter, is the x-coordinate, while input, the first parameter, is the + * y-coordinate.) + */ +def atan2[D <: DType, D2 <: DType]( + input: Tensor[D], + other: Tensor[D2] +): Tensor[FloatPromoted[Promoted[D, D2]]] = + Tensor(torchNative.atan2(input.native, other.native)) + +/** Computes the bitwise NOT of the given input tensor. The input tensor must be of integral or + * Boolean types. For bool tensors, it computes the logical NOT. + */ +def bitwiseNot[D <: BitwiseNN](input: Tensor[D]): Tensor[D] = + Tensor(torchNative.bitwise_not(input.native)) + +/** Computes the bitwise AND of `input` and `other`. For bool tensors, it computes the logical AND. + */ +def bitwiseAnd[D <: BitwiseNN](input: Tensor[D], other: Tensor[D]): Tensor[D] = + Tensor(torchNative.bitwise_and(input.native, other.native)) + +/** Computes the bitwise OR of `input` and `other`. For bool tensors, it computes the logical OR. + */ +def bitwiseOr[D <: BitwiseNN](input: Tensor[D], other: Tensor[D]): Tensor[D] = + Tensor(torchNative.bitwise_or(input.native, other.native)) + +/** Computes the bitwise XOR of `input` and `other`. For bool tensors, it computes the logical XOR. + */ +def bitwiseXor[D <: BitwiseNN](input: Tensor[D], other: Tensor[D]): Tensor[D] = + Tensor(torchNative.bitwise_xor(input.native, other.native)) + +/** Computes the left arithmetic shift of `input` by `other` bits. */ +def bitwiseLeftShift[D <: IntNN](input: Tensor[D], other: Tensor[D]): Tensor[D] = + Tensor(torchNative.bitwise_left_shift(input.native, other.native)) + +/** Computes the right arithmetic s\hift of `input` by `other` bits. 
*/
+def bitwiseRightShift[D <: IntNN](input: Tensor[D], other: Tensor[D]): Tensor[D] =
+  Tensor(torchNative.bitwise_right_shift(input.native, other.native))
+
+/** Returns a new tensor with the ceil of the elements of `input`, the smallest integer greater
+  * than or equal to each element.
+  */
+def ceil[D <: NumericRealNN](input: Tensor[D]): Tensor[D] =
+  Tensor(torchNative.ceil(input.native))
+
+/** Clamps all elements in `input` into the range `[min, max]`. Letting `min_value` and
+  * `max_value` be `min` and `max`, respectively, this returns
+  * `min(max(input, min_value), max_value)`. If `min` is `None`, there is no lower bound, and if
+  * `max` is `None`, there is no upper bound.
+  */
+def clamp[D <: NumericNN](
+    input: Tensor[D],
+    min: Option[Tensor[D]],
+    max: Option[Tensor[D]]
+): Tensor[D] =
+  Tensor(torchNative.clamp(input.native, toOptional(min), toOptional(max)))
+
+/** Computes the element-wise conjugate of the given `input` tensor. If `input` has a non-complex
+  * dtype, this function just returns `input`.
+  */
+def conjPhysical[D <: NumericNN](input: Tensor[D]): Tensor[D] =
+  Tensor(torchNative.conj_physical(input.native))
+
+/** Create a new floating-point tensor with the magnitude of `input` and the sign of `other`,
+  * elementwise.
+  */
+// TODO
+// def copysign[D <: DType](input: Tensor[D]): Tensor[D] =
+//   Tensor(torchNative.copysign(input.native))
+
+/** Returns a new tensor with the cosine of the elements of `input`. */
+def cos[D <: RealNN](input: Tensor[D]): Tensor[FloatPromoted[D]] =
+  Tensor(torchNative.cos(input.native))
+
+/** Returns a new tensor with the hyperbolic cosine of the elements of `input`. */
+def cosh[D <: RealNN](input: Tensor[D]): Tensor[FloatPromoted[D]] =
+  Tensor(torchNative.cosh(input.native))
+
+/** Returns a new tensor with each of the elements of `input` converted from angles in degrees to
+  * radians.
+  */
+def deg2rad[D <: RealNN](input: Tensor[D]): Tensor[FloatPromoted[D]] =
+  Tensor(torchNative.deg2rad(input.native))
+
+/** Divides each element of `input` by the corresponding element of `other`. */
+def div[D <: DType, D2 <: DType](input: Tensor[D], other: Tensor[D2]): Tensor[Div[D, D2]] =
+  Tensor(torchNative.div(input.native, other.native))
+
+export torch.special.digamma
+export torch.special.erf
+export torch.special.erfc
+export torch.special.erfinv
+
+/** Returns a new tensor with the exponential of the elements of the input tensor `input`. */
+def exp[D <: RealNN](input: Tensor[D]): Tensor[D] =
+  Tensor(torchNative.exp(input.native))
+
+export torch.special.exp2
+export torch.special.expm1
+
+/** Returns a new tensor with the data in `input` fake quantized per channel using `scale`,
+  * `zero_point`, `quant_min` and `quant_max`, across the channel specified by `axis`.
+  */
+// TODO Fix pytorch docs to add `axis` input
+def fakeQuantizePerChannelAffine(
+    input: Tensor[Float32],
+    scale: Tensor[Float32],
+    zeroPoint: Tensor[Int32 | Float16 | Float32],
+    axis: Long,
+    quantMin: Long,
+    quantMax: Long
+): Tensor[Float32] =
+  Tensor(
+    torchNative.fake_quantize_per_channel_affine(
+      input.native,
+      scale.native,
+      zeroPoint.native,
+      axis,
+      quantMin,
+      quantMax
+    )
+  )
+
+/** Returns a new tensor with the data in `input` fake quantized using `scale`, `zero_point`,
+  * `quant_min` and `quant_max`.
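+  *
+  * Conceptually, the computation is (a sketch):
+  * `out = (clamp(round(input / scale) + zero_point, quant_min, quant_max) - zero_point) * scale`.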
+ */ +def fakeQuantizePerTensorAffine( + input: Tensor[Float32], + scale: Tensor[Float32], + zeroPoint: Tensor[Int32], + quantMin: Long, + quantMax: Long +): Tensor[Float32] = + Tensor( + torchNative.fake_quantize_per_tensor_affine( + input.native, + scale.native, + zeroPoint.native, + quantMin, + quantMax + ) + ) + +def fakeQuantizePerTensorAffine( + input: Tensor[Float32], + scale: Double, + zeroPoint: Long, + quantMin: Long, + quantMax: Long +): Tensor[Float32] = + Tensor( + torchNative.fake_quantize_per_tensor_affine(input.native, scale, zeroPoint, quantMin, quantMax) + ) + +// TODO torch.fix // Alias for torch.trunc + +/** Raises `input` to the power of `exponent`, elementwise, in double precision. If neither input is + * complex returns a `torch.float64` tensor, and if one or more inputs is complex returns a + * `torch.complex128` tensor. + */ +def floatPower[D <: DType, D2 <: DType]( + input: Tensor[D], + exponent: Tensor[D2] +): Tensor[ComplexPromoted[D, D2]] = + Tensor(torchNative.float_power(input.native, exponent.native)) + +def floatPower[D <: DType, S <: ScalaType]( + input: S, + exponent: Tensor[D] +): Tensor[ComplexPromoted[ScalaToDType[S], D]] = + Tensor(torchNative.float_power(toScalar(input), exponent.native)) + +def floatPower[D <: DType, S <: ScalaType]( + input: Tensor[D], + exponent: ScalaType +): Tensor[ComplexPromoted[D, ScalaToDType[S]]] = + Tensor(torchNative.float_power(input.native, toScalar(exponent))) + +/** Returns a new tensor with the floor of the elements of `input`, the largest integer less than or + * equal to each element. + */ +def floor[D <: NumericRealNN](input: Tensor[D]): Tensor[D] = + Tensor(torchNative.floor(input.native)) + +/** Computes `input` divided by `other`, elementwise, and floors the result. */ +def floorDivide[D <: DType, D2 <: DType]( + input: Tensor[D], + other: TensorOrReal[D2] +): Tensor[Promoted[D, D2]] = + Tensor( + (input, other) match + case (input: Tensor[D], other: Tensor[D2]) => + torchNative.floor_divide(input.native, other.native) + case (input: Tensor[D], other: Real) => + torchNative.floor_divide(input.native, toScalar(other)) + ) + +/** Applies C++’s `std::fmod` entrywise. The result has the same sign as the dividend `input` and + * its absolute value is less than that of `other`. + */ +// NOTE: When the divisor is zero, returns NaN for floating point dtypes on both CPU and GPU; raises RuntimeError for integer division by zero on CPU; Integer division by zero on GPU may return any value. +def fmod[D <: RealNN](input: Tensor[D], other: TensorOrReal[D]): Tensor[D] = + Tensor( + other match + case (other: Tensor[D]) => + torchNative.fmod(input.native, other.native) + case (other: Real) => + torchNative.fmod(input.native, toScalar(other)) + ) + +/** Computes the fractional portion of each element in `input`. */ +def frac[D <: FloatNN](input: Tensor[D]): Tensor[D] = + Tensor(torchNative.frac(input.native)) + +/** Decomposes `input` into `mantissa` and `exponent` tensors such that `input = mantissa * (2 ** + * exponent)` The range of mantissa is the open interval (-1, 1). + */ +def frexp[D <: FloatNN](input: Tensor[D]): (Tensor[FloatPromoted[D]], Tensor[Int32]) = + val nativeTuple = torchNative.frexp(input.native) + (Tensor(nativeTuple.get0), new Int32Tensor(nativeTuple.get1)) + +// TODO implement +/** */ +// def gradient[D <: DType](input: Tensor[D]): Tensor[D] = +// Tensor(torchNative.???) + +/** Returns a new tensor containing imaginary values of the `input` tensor. 
The returned tensor and + * `input` share the same underlying storage. + */ +def imag[D <: ComplexNN](input: Tensor[D]): Tensor[ComplexToReal[D]] = + Tensor(torchNative.imag(input.native)) + +/** Multiplies `input` by 2 ** `other`. */ +def ldexp[D <: DType](input: Tensor[D], other: Tensor[D]): Tensor[D] = + Tensor(torchNative.ldexp(input.native, other.native)) + +/** Does a linear interpolation of two tensors `start` (given by `input`) and `end` (given by + * `other`) based on a scalar or tensor weight and returns the resulting out tensor. out = start + + * weight × (end − start) + */ +def lerp[D <: DType]( + input: Tensor[D], + other: Tensor[D], + weight: Tensor[D] | Float | Double +): Tensor[D] = + Tensor( + weight match + case weight: Tensor[D] => torchNative.lerp(input.native, other.native, weight.native) + case weight: Float => torchNative.lerp(input.native, other.native, toScalar(weight)) + case weight: Double => torchNative.lerp(input.native, other.native, toScalar(weight)) + ) + +/** Computes the natural logarithm of the absolute value of the gamma function on `input`. */ +def lgamma[D <: RealNN](input: Tensor[D]): Tensor[D] = + Tensor(torchNative.lgamma(input.native)) + +/** Returns a new tensor with the natural logarithm of the elements of `input`. */ +def log[D <: RealNN](input: Tensor[D]): Tensor[FloatPromoted[D]] = + Tensor(torchNative.log(input.native)) + +/** Returns a new tensor with the logarithm to the base 10 of the elements of `input`. */ +def log10[D <: RealNN](input: Tensor[D]): Tensor[FloatPromoted[D]] = + Tensor(torchNative.log10(input.native)) + +/** Returns a new tensor with the natural logarithm of (1 + input). */ +def log1p[D <: RealNN](input: Tensor[D]): Tensor[FloatPromoted[D]] = + Tensor(torchNative.log1p(input.native)) /** Returns a new tensor with the logarithm to the base 2 of the elements of `input`. */ -def log2[D <: DType](input: Tensor[D]) = Tensor(torchNative.log2(input.native)) - -// TODO logaddexp -// TODO logaddexp2 -// TODO logical_and -// TODO locigal_not -// TODO logical_or -// TODO logical_xor -// TODO logit -// TODO hypot -// TODO i0 -// TODO igamma -// TODO igammac +def log2[D <: RealNN](input: Tensor[D]): Tensor[FloatPromoted[D]] = + Tensor(torchNative.log2(input.native)) + +/** Logarithm of the sum of exponentiations of the inputs. Calculates pointwise log `log(e**x + + * e**y)`. This function is useful in statistics where the calculated probabilities of events may + * be so small as to exceed the range of normal floating point numbers. In such cases the logarithm + * of the calculated probability is stored. This function allows adding probabilities stored in + * such a fashion. This op should be disambiguated with `torch.logsumexp()` which performs a + * reduction on a single tensor. + */ +def logaddexp[D <: DType, D2 <: DType]( + input: Tensor[D], + other: Tensor[D2] +): Tensor[Promoted[D, D2]] = + Tensor(torchNative.logaddexp(input.native, other.native)) + +/** Logarithm of the sum of exponentiations of the inputs in base-2. Calculates pointwise `log2(2**x + * + 2**y)`. See torch.logaddexp() for more details. + */ +def logaddexp2[D <: DType, D2 <: DType]( + input: Tensor[D], + other: Tensor[D2] +): Tensor[Promoted[D, D2]] = + Tensor(torchNative.logaddexp2(input.native, other.native)) + +/** Computes the element-wise logical AND of the given input tensors. Zeros are treated as False and + * nonzeros are treated as True. 
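+  *
+  * An illustrative example:
+  * {{{
+  *   torch.logicalAnd(Tensor(Seq(true, false, true)), Tensor(Seq(true, false, false)))
+  *   // Tensor(Seq(true, false, false))
+  * }}}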
+ */ +def logicalAnd[D <: DType, D2 <: DType](input: Tensor[D], other: Tensor[D2]): Tensor[Bool] = + Tensor(torchNative.logical_and(input.native, other.native)) + +/** Computes the element-wise logical NOT of the given input tensor. TODO If not specified, the + * output tensor will have the bool dtype. If the input tensor is not a bool tensor, zeros are + * treated as False and non-zeros are treated as True. + */ +def logicalNot[D <: RealNN](input: Tensor[D]): Tensor[Bool] = + Tensor(torchNative.logical_not(input.native)) +/** Computes the element-wise logical OR of the given input tensors. Zeros are treated as False and + * nonzeros are treated as True. + */ +def logicalOr[D <: DType, D2 <: DType](input: Tensor[D], other: Tensor[D2]): Tensor[Bool] = + Tensor(torchNative.logical_or(input.native, other.native)) + +/** Computes the element-wise logical XOR of the given input tensors. Zeros are treated as False and + * nonzeros are treated as True. + */ +def logicalXor[D <: DType, D2 <: DType](input: Tensor[D], other: Tensor[D2]): Tensor[Bool] = + Tensor(torchNative.logical_or(input.native, other.native)) + +export torch.special.logit + +/** Given the legs of a right triangle, return its hypotenuse. */ +def hypot[D <: DType, D2 <: DType]( + input: Tensor[D], + other: Tensor[D2] +): Tensor[FloatPromoted[Promoted[D, D2]]] = + Tensor(torchNative.hypot(input.native, other.native)) + +export torch.special.i0 +export torch.special.igamma +export torch.special.igammac + +/** Multiplies input by other. */ def mul[D <: DType, D2 <: DType](input: Tensor[D], other: Tensor[D2]): Tensor[Promoted[D, D2]] = Tensor(torchNative.mul(input.native, other.native)) -// TODO mlvgamma -// TODO nan_to_num -// TODO neg -// TODO nextafter -// TODO polygamma -// TODO positive -// TODO pow +export torch.special.mvlgamma + +/** Replaces NaN, positive infinity, and negative infinity values in `input` with the values + * specified by nan, posinf, and neginf, respectively. By default, NaNs are replaced with zero, + * positive infinity is replaced with the greatest finite value representable by input’s dtype, and + * negative infinity is replaced with the least finite value representable by input’s dtype. + */ +def nanToNum[D <: FloatNN]( + input: Tensor[D], + nan: Option[Double] = None, + posinf: Option[Double], + neginf: Option[Double] +): Tensor[D] = + Tensor( + torchNative.nan_to_num(input.native, toOptional(nan), toOptional(posinf), toOptional(neginf)) + ) + +/** Returns a new tensor with the negative of the elements of `input`. */ +def neg[D <: NumericNN](input: Tensor[D]): Tensor[D] = + Tensor(torchNative.neg(input.native)) + +/** Return the next floating-point value after `input` towards `other`, elementwise. */ +def nextafter[D <: DType, D2 <: DType]( + input: Tensor[D], + other: Tensor[D2] +): Tensor[FloatPromoted[Promoted[D, D2]]] = + Tensor(torchNative.nextafter(input.native, other.native)) + +export torch.special.polygamma + +/** Returns input. Normally throws a runtime error if input is a bool tensor in pytorch. */ +def positive[D <: NumericNN](input: Tensor[D]): Tensor[D] = + Tensor(torchNative.positive(input.native)) + +/** Takes the power of each element in `input` with exponent and returns a tensor with the result. + * `exponent` can be either a single float number or a Tensor with the same number of elements as + * input. 
+ */ +// TODO handle Scalar `input` +def pow[D <: DType, D2 <: DType]( + input: Tensor[D], + exponent: TensorOrReal[D2] +): Tensor[FloatPromoted[D]] = + Tensor( + (input, exponent) match + case (input: Tensor[D], exponent: Tensor[D2]) => + torchNative.pow(input.native, exponent.native) + case (input: Tensor[D], exponent: Real) => + torchNative.pow(input.native, toScalar(exponent)) + ) + // TODO quantized_batch_norm // TODO quantized_max_pool1d // TODO quantized_max_pool2d -// TODO rad2deg -// TODO real -// TODO reciprocal -// TODO remainder -// TODO round -// TODO rsqrt -// TODO sigmoid -// TODO sign -// TODO sgn -// TODO signbit - -def sin[D <: DType](input: Tensor[D]) = Tensor(torchNative.sin(input.native)) - -// TODO sinc -// TODO sinh + +/** Returns a new tensor with each of the elements of `input` converted from angles in radians to + * degrees. + */ +def rad2Deg[D <: RealNN | Bool](input: Tensor[D]): Tensor[FloatPromoted[D]] = + Tensor(torchNative.rad2deg(input.native)) + +/** Returns a new tensor containing real values of the self tensor. The returned tensor and self + * share the same underlying storage. + */ +def real[D <: DType](input: Tensor[D]): Tensor[ComplexToReal[D]] = + Tensor(torchNative.real(input.native)) + +/** Returns a new tensor with the reciprocal of the elements of `input` */ +def reciprocal[D <: RealNN](input: Tensor[D]): Tensor[FloatPromoted[D]] = + Tensor(torchNative.reciprocal(input.native)) + +/** Computes Python’s modulus operation entrywise. The result has the same sign as the divisor + * `other` and its absolute value is less than that of `other`. + */ +// TODO handle Scalar `input` +def remainder[D <: DType, D2 <: DType]( + input: Tensor[D], + other: TensorOrReal[D2] +): Tensor[FloatPromoted[D]] = + Tensor( + (input, other) match + case (input: Tensor[D], other: Tensor[D2]) => + torchNative.remainder(input.native, other.native) + case (input: Tensor[D], other: Real) => + torchNative.remainder(input.native, toScalar(other)) + ) + +/** Rounds elements of `input` to the nearest integer. If decimals is negative, it specifies the + * number of positions to the left of the decimal point. + */ +def round[D <: NumericNN](input: Tensor[D], decimals: Long = 0): Tensor[D] = + Tensor(torchNative.round(input.native, decimals)) + +/** Returns a new tensor with the reciprocal of the square-root of each of the elements of `input`. + */ +def rsqrt[D <: RealNN](input: Tensor[D]): Tensor[FloatPromoted[D]] = + Tensor(torchNative.rsqrt(input.native)) + +export torch.special.sigmoid + +/** Returns a new tensor with the signs of the elements of `input`. */ +def sign[D <: RealNN](input: Tensor[D]): Tensor[D] = + Tensor(torchNative.sign(input.native)) + +/** This function is an extension of `torch.sign()` to complex tensors. It computes a new tensor + * whose elements have the same angles as the corresponding elements of `input` and absolute values + * (i.e. magnitudes) of one for complex tensors and is equivalent to torch.sign() for non-complex + * tensors. + */ +def sgn[D <: DType](input: Tensor[D]): Tensor[D] = + Tensor(torchNative.sgn(input.native)) + +/** Tests if each element of `input`` has its sign bit set or not. */ +def signbit[D <: RealNN](input: Tensor[D]): Tensor[Bool] = + Tensor(torchNative.signbit(input.native)) + +/** Returns a new tensor with the sine of the elements of `input`. 
*/ +def sin[D <: RealNN](input: Tensor[D]): Tensor[FloatPromoted[D]] = + Tensor(torchNative.sin(input.native)) + +export torch.special.sinc + +/** Returns a new tensor with the hyperbolic sine of the elements of `input`. */ +def sinh[D <: RealNN](input: Tensor[D]): Tensor[FloatPromoted[D]] = + Tensor(torchNative.sinh(input.native)) + // TODO softmax export torch.nn.functional.softmax -// TODO sqrt -// TODO square -// TODO sub -// TODO tan -// TODO tanh +/** Returns a new tensor with the square-root of the elements of `input`. */ +def sqrt[D <: RealNN](input: Tensor[D]): Tensor[FloatPromoted[D]] = + Tensor(torchNative.sqrt(input.native)) + +/** Returns a new tensor with the square of the elements of `input`. */ +def square[D <: RealNN](input: Tensor[D]): Tensor[NumericPromoted[D]] = + Tensor(torchNative.square(input.native)) + +/** Subtracts `other`, scaled by `alpha`, from `input`. */ +def sub[D <: DType, D2 <: DType](input: Tensor[D], other: Tensor[D2]): Tensor[Promoted[D, D2]] = + Tensor(torchNative.sub(input.native, other.native)) + +def sub[D <: DType, D2 <: DType]( + input: Tensor[D], + other: Tensor[D2], + alpha: ScalaType +): Tensor[Promoted[D, D2]] = + Tensor(torchNative.sub(input.native, other.native, toScalar(alpha))) + +def sub[D <: DType, S <: ScalaType]( + input: Tensor[D], + other: S, + alpha: ScalaType +): Tensor[Promoted[D, ScalaToDType[S]]] = + Tensor(torchNative.sub(input.native, toScalar(other), toScalar(alpha))) + +/** Returns a new tensor with the tangent of the elements of `input`. */ +def tan[D <: RealNN](input: Tensor[D]): Tensor[FloatPromoted[D]] = + Tensor(torchNative.tan(input.native)) + +/** Returns a new tensor with the hyperbolic tangent of the elements of `input`. */ +def tanh[D <: RealNN](input: Tensor[D]): Tensor[FloatPromoted[D]] = + Tensor(torchNative.tanh(input.native)) + // TODO true_divide -// TODO trunc + +/** Returns a new tensor with the truncated integer values of the elements of `input`. 
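+  *
+  * An illustrative example (values matching the unit tests in spirit):
+  * {{{
+  *   torch.trunc(Tensor(Seq(3.47, 0.54, -0.91))) // Tensor(Seq(3.0, 0.0, -0.0))
+  * }}}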
*/ +def trunc[D <: NumericRealNN](input: Tensor[D]): Tensor[D] = + Tensor(torchNative.trunc(input.native)) + // TODO xlogy // End Pointwise Ops diff --git a/core/src/test/scala/torch/Generators.scala b/core/src/test/scala/torch/Generators.scala index a63ea608..953d75d8 100644 --- a/core/src/test/scala/torch/Generators.scala +++ b/core/src/test/scala/torch/Generators.scala @@ -32,7 +32,7 @@ object Generators: val genTensorSize = Gen.choose(0, 5).flatMap(listSize => Gen.listOfN(listSize, genDimSize)) given Arbitrary[Device] = Arbitrary(genDevice) - val genDType = Gen.oneOf( + val allDTypes = List( int8, uint8, int16, @@ -47,10 +47,17 @@ object Generators: // qint8, // quint8, // qint32, - bfloat16, + bfloat16 // quint4x2, - float16 + // float16, // NOTE: A lot of CPU do not support this dtype // undefined, // numoptions ) + + inline def genTensor[D <: DType]: Gen[Tensor[D]] = + Gen.oneOf(allDTypes.filter(_.isInstanceOf[D])).map { dtype => + ones(10, dtype = dtype.asInstanceOf[D]) + } + + val genDType = Gen.oneOf(allDTypes) given Arbitrary[DType] = Arbitrary(genDType) diff --git a/core/src/test/scala/torch/TensorSuite.scala b/core/src/test/scala/torch/TensorSuite.scala index 4362220b..cb369883 100644 --- a/core/src/test/scala/torch/TensorSuite.scala +++ b/core/src/test/scala/torch/TensorSuite.scala @@ -30,9 +30,58 @@ import Gen._ import Arbitrary.arbitrary import DeviceType.CPU import Generators.{*, given} +import scala.util.Try +import spire.math.Complex +import spire.implicits.DoubleAlgebra class TensorSuite extends ScalaCheckSuite { + inline private def testUnaryOp[In <: DType, InS <: ScalaType]( + op: Tensor[In] => Tensor[?], + opName: String, + inline inputTensor: Tensor[ScalaToDType[InS]], + inline expectedTensor: Tensor[?], + absolutePrecision: Double = 1e-04 + )(using ScalaToDType[InS] <:< In): Unit = + val propertyTestName = s"${opName}.property-test" + test(propertyTestName) { + forAll(genTensor[In]) { (tensor) => + val result = Try(op(tensor)) + // TODO Validate output types + assert( + result.isSuccess, + s"""| + |Tensor operation 'torch.${opName}' does not support ${tensor.dtype} inputs + | + |${result.failed.get} + """.stripMargin + ) + } + } + val unitTestName = s"${opName}.unit-test" + test(unitTestName) { + val outputTensor = op(inputTensor.asInstanceOf[Tensor[In]]) + val allclose = outputTensor.allclose( + other = expectedTensor, + atol = absolutePrecision, + equalNan = true + ) + assert( + allclose, + s"""| + |Tensor results are not all close for 'torch.${opName}' + | + |Input tensor: + |${inputTensor} + | + |Output tensor: + |${outputTensor} + | + |Expected tensor: + |${expectedTensor}""".stripMargin + ) + } + test("arange") { val t0 = arange(0, 10) assertEquals(t0.toSeq, Seq.range(0, 10)) @@ -107,6 +156,403 @@ class TensorSuite extends ScalaCheckSuite { assert(t.grad.equal(torch.ones(Seq(3)))) } + // TODO addcdiv + // TODO addcmul + // TODO angle + + testUnaryOp( + op = asin, + opName = "asin", + inputTensor = Tensor(Seq(-0.5962, 1.4985, -0.4396, 1.4525)), + expectedTensor = Tensor(Seq(-0.6387, Double.NaN, -0.4552, Double.NaN)) + ) + testUnaryOp( + op = asinh, + opName = "asinh", + inputTensor = Tensor(Seq(0.1606, -1.4267, -1.0899, -1.0250)), + expectedTensor = Tensor(Seq(0.1599, -1.1534, -0.9435, -0.8990)) + ) + testUnaryOp( + op = atan, + opName = "atan", + inputTensor = Tensor(Seq(0.2341, 0.2539, -0.6256, -0.6448)), + expectedTensor = Tensor(Seq(0.2299, 0.2487, -0.5591, -0.5727)) + ) + testUnaryOp( + op = atanh, + opName = "atanh", + inputTensor = Tensor(Seq(-0.9385, 
0.2968, -0.8591, -0.1871)), + expectedTensor = Tensor(Seq(-1.7253, 0.3060, -1.2899, -0.1893)) + ) + + // TODO atan2 + + // TODO Test boolean cases for bitwise_not + // https://pytorch.org/docs/stable/generated/torch.bitwise_not.html + testUnaryOp( + op = bitwiseNot, + opName = "bitwiseNot", + inputTensor = Tensor(Seq(-1, -2, 3)), + expectedTensor = Tensor(Seq(0, 1, -4)) + ) + + // TODO bitwise_and + // TODO bitwise_or + // TODO bitwise_xor + // TODO bitwise_left_shift + // TODO bitwise_right_shift + + testUnaryOp( + op = ceil, + opName = "ceil", + inputTensor = Tensor(Seq(-0.6341, -1.4208, -1.0900, 0.5826)), + expectedTensor = Tensor(Seq(-0.0, -1.0, -1.0, 1.0)) + ) + + // TODO clamp + + // TODO Handle Complex Tensors + // testUnaryOp( + // op = conjPhysical, + // opName = "conjPhysical", + // inputTensor = Tensor(Seq(180.0, -180.0, 360.0, -360.0, 90.0, -90.0)), + // expectedTensor = Tensor(Seq(3.1416, -3.1416, 6.2832, -6.2832, 1.5708, -1.5708)) + // ) + + // TODO copysign + + testUnaryOp( + op = cos, + opName = "cos", + inputTensor = Tensor(Seq(1.4309, 1.2706, -0.8562, 0.9796)), + expectedTensor = Tensor(Seq(0.1395, 0.2957, 0.6553, 0.5574)) + ) + + testUnaryOp( + op = cosh, + opName = "cosh", + inputTensor = Tensor(Seq(0.1632, 1.1835, -0.6979, -0.7325)), + expectedTensor = Tensor(Seq(1.0133, 1.7860, 1.2536, 1.2805)) + ) + + testUnaryOp( + op = deg2rad, + opName = "deg2rad", + inputTensor = Tensor(Seq(180.0, -180.0, 360.0, -360.0, 90.0, -90.0)), + expectedTensor = Tensor(Seq(3.1416, -3.1416, 6.2832, -6.2832, 1.5708, -1.5708)) + ) + + // TODO div + + testUnaryOp( + op = digamma, + opName = "digamma", + inputTensor = Tensor(Seq(1, 0.5)), + expectedTensor = Tensor(Seq(-0.5772, -1.9635)) + ) + + testUnaryOp( + op = erf, + opName = "erf", + inputTensor = Tensor(Seq(0, -1.0, 10.0)), + expectedTensor = Tensor(Seq(0.0, -0.8427, 1.0)) + ) + + testUnaryOp( + op = erfc, + opName = "erfc", + inputTensor = Tensor(Seq(0, -1.0, 10.0)), + expectedTensor = Tensor(Seq(1.0, 1.8427, 0.0)) + ) + + testUnaryOp( + op = erfinv, + opName = "erfinv", + inputTensor = Tensor(Seq(0.0, 0.5, -1.0)), + expectedTensor = Tensor(Seq(0.0, 0.4769, Double.NegativeInfinity)) + ) + + testUnaryOp( + op = exp, + opName = "exp", + inputTensor = Tensor(Seq(0, 0.6931)), + expectedTensor = Tensor(Seq(1.0, 2.0)) + ) + + testUnaryOp( + op = exp2, + opName = "exp2", + inputTensor = Tensor(Seq(0.0, 1.0, 3.0, 4.0)), + expectedTensor = Tensor(Seq(1.0, 2.0, 8.0, 16.0)) + ) + + testUnaryOp( + op = expm1, + opName = "expm1", + inputTensor = Tensor(Seq(0, 0.6931)), + expectedTensor = Tensor(Seq(0.0, 1.0)) + ) + + // TODO fakeQuantizePerChannelAffine + // TODO fakeQuantizePerTensorAffine + // TODO floatPower + + testUnaryOp( + op = floor, + opName = "floor", + inputTensor = Tensor(Seq(-0.8166, 1.5308, -0.2530, -0.2091)), + expectedTensor = Tensor(Seq(-1.0, 1.0, -1.0, -1.0)) + ) + + // TODO floorDivide + // TODO fmod + + testUnaryOp( + op = frac, + opName = "frac", + inputTensor = Tensor(Seq(1, 2.5, -3.2)), + expectedTensor = Tensor(Seq(0.0, 0.5, -0.2)) + ) + + // TODO Handle Tuple Tensor Output + // https://pytorch.org/docs/stable/generated/torch.frexp.html + // testUnaryOp( + // op = frexp, + // opName = "frexp", + // inputTensor = Tensor(Seq(0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0)), + // expectedTensor = Tensor(Seq(0.5724, 0.0, -0.1208)) + // ) + + // TODO gradient + + // TODO Handle Complex Tensors + // testUnaryOp( + // op = imag, + // opName = "imag", + // inputTensor = Tensor(Seq((0.3100+0.3553j), (-0.5445-0.7896j), (-1.6492-0.0633j), 
(-0.0638-0.8119j))), + // expectedTensor = Tensor(Seq(0.3553, -0.7896, -0.0633, -0.8119)) + // ) + + // TODO ldexp + // TODO lerp + + testUnaryOp( + op = lgamma, + opName = "lgamma", + inputTensor = Tensor(Seq(0.5, 1.0, 1.5)), + expectedTensor = Tensor(Seq(0.5724, 0.0, -0.1208)) + ) + + testUnaryOp( + op = log, + opName = "log", + inputTensor = Tensor(Seq(4.7767, 4.3234, 1.2156, 0.2411, 4.5739)), + expectedTensor = Tensor(Seq(1.5637, 1.4640, 0.1952, -1.4226, 1.5204)) + ) + + testUnaryOp( + op = log10, + opName = "log10", + inputTensor = Tensor(Seq(0.5224, 0.9354, 0.7257, 0.1301, 0.2251)), + expectedTensor = Tensor(Seq(-0.2820, -0.0290, -0.1392, -0.8857, -0.6476)) + ) + + testUnaryOp( + op = log1p, + opName = "log1p", + inputTensor = Tensor(Seq(-1.0090, -0.9923, 1.0249, -0.5372, 0.2492)), + expectedTensor = Tensor(Seq(Double.NaN, -4.8653, 0.7055, -0.7705, 0.2225)), + absolutePrecision = 1e-2 + ) + + testUnaryOp( + op = log2, + opName = "log2", + inputTensor = Tensor(Seq(0.8419, 0.8003, 0.9971, 0.5287, 0.0490)), + expectedTensor = Tensor(Seq(-0.2483, -0.3213, -0.0042, -0.9196, -4.3504)), + absolutePrecision = 1e-2 + ) + + // TODO logaddexp + // TODO logaddexp2 + // TODO logicalAnd + + // TODO Handle numeric cases for logical_not + // https://pytorch.org/docs/stable/generated/torch.logical_not.html + testUnaryOp( + op = logicalNot, + opName = "logicalNot", + inputTensor = Tensor(Seq(true, false)), + expectedTensor = Tensor(Seq(false, true)) + ) + + // TODO logicalOr + // TODO logicalXor + // TODO logit + // TODO hypot + + testUnaryOp( + op = i0, + opName = "i0", + inputTensor = Tensor(Seq(0.0, 1.0, 2.0, 3.0, 4.0)), + expectedTensor = Tensor(Seq(1.0, 1.2661, 2.2796, 4.8808, 11.3019)) + ) + + // TODO igamma + // TODO igammac + // TODO mul + // TODO mvlgamma + // TODO nanToNum + + testUnaryOp( + op = neg, + opName = "neg", + inputTensor = Tensor(Seq(0.0090, -0.2262, -0.0682, -0.2866, 0.3940)), + expectedTensor = Tensor(Seq(-0.0090, 0.2262, 0.0682, 0.2866, -0.3940)) + ) + + // TODO nextafter + // TODO polygamma + + testUnaryOp( + op = positive, + opName = "positive", + inputTensor = Tensor(Seq(0.0090, -0.2262, -0.0682, -0.2866, 0.3940)), + expectedTensor = Tensor(Seq(0.0090, -0.2262, -0.0682, -0.2866, 0.3940)) + ) + + // TODO pow + // TODO quantized_batch_norm + // TODO quantized_max_pool1d + // TODO quantized_max_pool2d + + testUnaryOp( + op = rad2Deg, + opName = "rad2Deg", + inputTensor = Tensor(Seq(3.142, -3.142, 6.283, -6.283, 1.570, -1.570)), + expectedTensor = Tensor(Seq(180.0233, -180.0233, 359.9894, -359.9894, 89.9544, -89.9544)) + ) + + // TODO Handle Complex Tensors + // testUnaryOp( + // op = real, + // opName = "real", + // inputTensor = Tensor(Seq((0.3100+0.3553j), (-0.5445-0.7896j), (-1.6492-0.0633j), (-0.0638-0.8119j))), + // expectedTensor = Tensor(Seq(0.3100, -0.5445, -1.6492, -0.0638)) + // ) + + testUnaryOp( + op = reciprocal, + opName = "reciprocal", + inputTensor = Tensor(Seq(-0.4595, -2.1219, -1.4314, 0.7298)), + expectedTensor = Tensor(Seq(-2.1763, -0.4713, -0.6986, 1.3702)) + ) + + // TODO remainder + // TODO round + + testUnaryOp( + op = rsqrt, + opName = "rsqrt", + inputTensor = Tensor(Seq(-0.0370, 0.2970, 1.5420, -0.9105)), + expectedTensor = Tensor(Seq(Double.NaN, 1.8351, 0.8053, Double.NaN)), + absolutePrecision = 1e-3 + ) + + // TODO sigmoid + + testUnaryOp( + op = sign, + opName = "sign", + inputTensor = Tensor(Seq(0.7, -1.2, 0.0, 2.3)), + expectedTensor = Tensor(Seq(1.0, -1.0, 0.0, 1.0)) + ) + + // TODO Fix Complex Tensor creation + // testUnaryOp( + // op = sgn, + 
// opName = "sgn", + // inputTensor = Tensor(Seq(Complex(3.0,4.0), Complex(7.0, -24.0), Complex(0.0, 0.0), Complex(1.0, 2.0))), + // expectedTensor = Tensor(Seq(false, true, false, true, false)) + // ) + + testUnaryOp( + op = signbit, + opName = "signbit", + inputTensor = Tensor(Seq(0.7, -1.2, 0.0, -0.0, 2.3)), + expectedTensor = Tensor(Seq(false, true, false, true, false)) + ) + + testUnaryOp( + op = sin, + opName = "sin", + inputTensor = Tensor(Seq(-0.5461, 0.1347, -2.7266, -0.2746)), + expectedTensor = Tensor(Seq(-0.5194, 0.1343, -0.4032, -0.2711)) + ) + + testUnaryOp( + op = sinc, + opName = "sinc", + inputTensor = Tensor(Seq(0.2252, -0.2948, 1.0267, -1.1566)), + expectedTensor = Tensor(Seq(0.9186, 0.8631, -0.0259, -0.1300)) + ) + + testUnaryOp( + op = sinh, + opName = "sinh", + inputTensor = Tensor(Seq(0.5380, -0.8632, -0.1265, 0.9399)), + expectedTensor = Tensor(Seq(0.5644, -0.9744, -0.1268, 1.0845)) + ) + + testUnaryOp( + op = sqrt, + opName = "sqrt", + inputTensor = Tensor(Seq(-2.0755, 1.0226, 0.0831, 0.4806)), + expectedTensor = Tensor(Seq(Double.NaN, 1.0112, 0.2883, 0.6933)) + ) + + testUnaryOp( + op = square, + opName = "square", + inputTensor = Tensor(Seq(-2.0755, 1.0226, 0.0831, 0.4806)), + expectedTensor = Tensor(Seq(4.3077, 1.0457, 0.0069, 0.2310)) + ) + + test("sub") { + val a = Tensor(Seq(1, 2)) + val b = Tensor(Seq(0, 1)) + val res = sub(a, b) + assertEquals(res, Tensor(Seq(1, 1))) + + val resAlpha = sub(a, b, alpha = 2) + assertEquals( + resAlpha, + Tensor(Seq(1, 0)) + ) + } + + testUnaryOp( + op = tan, + opName = "tan", + inputTensor = Tensor(Seq(-1.2027, -1.7687, 0.4412, -1.3856)), + expectedTensor = Tensor(Seq(-2.5930, 4.9859, 0.4722, -5.3366)), + absolutePrecision = 1e-2 + ) + + testUnaryOp( + op = tanh, + opName = "tanh", + inputTensor = Tensor(Seq(0.8986, -0.7279, 1.1745, 0.2611)), + expectedTensor = Tensor(Seq(0.7156, -0.6218, 0.8257, 0.2553)) + ) + + testUnaryOp( + op = trunc, + opName = "trunc", + inputTensor = Tensor(Seq(3.4742, 0.5466, -0.8008, -0.9079)), + expectedTensor = Tensor(Seq(3.0, 0.0, -0.0, -0.0)) + ) + test("indexing") { val tensor = torch.arange(0, 16).reshape(4, 4) // first row From b5c21c84a0ef986b3801343a8ba5c6123a454f04 Mon Sep 17 00:00:00 2001 From: David Gomez-Urquiza Date: Sun, 28 May 2023 18:14:52 -0600 Subject: [PATCH 4/5] Fix complex number Tensor creation and item accessor + enable tests that use it --- build.sbt | 1 + core/src/main/scala/torch/Tensor.scala | 40 ++++++++--- core/src/main/scala/torch/Types.scala | 36 ++++++++++ .../torch/internal/NativeConverters.scala | 5 +- core/src/test/scala/torch/TensorSuite.scala | 69 +++++++++++-------- 5 files changed, 109 insertions(+), 42 deletions(-) create mode 100644 core/src/main/scala/torch/Types.scala diff --git a/build.sbt b/build.sbt index 4664793b..95d4f127 100644 --- a/build.sbt +++ b/build.sbt @@ -80,6 +80,7 @@ lazy val core = project libraryDependencies ++= Seq( "org.bytedeco" % "pytorch" % s"$pytorchVersion-${javaCppVersion.value}", "org.typelevel" %% "spire" % "0.18.0", + "org.typelevel" %% "shapeless3-typeable" % "3.2.0", "com.lihaoyi" %% "os-lib" % "0.9.0", "com.lihaoyi" %% "sourcecode" % "0.3.0", "dev.dirs" % "directories" % "26", diff --git a/core/src/main/scala/torch/Tensor.scala b/core/src/main/scala/torch/Tensor.scala index 5ab21af0..26e2f0fd 100644 --- a/core/src/main/scala/torch/Tensor.scala +++ b/core/src/main/scala/torch/Tensor.scala @@ -321,17 +321,21 @@ sealed abstract class Tensor[D <: DType]( /* private[torch] */ val native: pyto def item: DTypeToScala[D] = import 
ScalarType.* val out = native.dtype().toScalarType().intern() match - case Byte => UByte(native.item_int()) - case Char => native.item_byte() - case Short => native.item_short() - case Int => native.item_int() - case Long => native.item_long() - case Half => native.item().toHalf.asFloat() - case Float => native.item_float() - case Double => native.item_double() - case ComplexHalf => ??? // TODO how to access complex scalar values? - case ComplexFloat => ??? - case ComplexDouble => ??? + case Byte => UByte(native.item_int()) + case Char => native.item_byte() + case Short => native.item_short() + case Int => native.item_int() + case Long => native.item_long() + case Half => native.item().toHalf.asFloat() + case Float => native.item_float() + case Double => native.item_double() + case ComplexHalf => ??? // TODO how to access complex scalar values? + case ComplexFloat => + val b = native.contiguous.createBuffer[FloatBuffer] + Complex(b.get(), b.get()) + case ComplexDouble => + val b = native.contiguous.createBuffer[DoubleBuffer] + Complex(b.get(), b.get()) case Bool => native.item().toBool case QInt8 => native.item_byte() case QUInt8 => native.item_short() @@ -771,6 +775,20 @@ object Tensor: case longs: Array[Long] => (new LongPointer(LongBuffer.wrap(longs)), int64) case floats: Array[Float] => (new FloatPointer(FloatBuffer.wrap(floats)), float32) case doubles: Array[Double] => (new DoublePointer(DoubleBuffer.wrap(doubles)), float64) + case complexFloatArray(complexFloats) => + ( + new FloatPointer( + FloatBuffer.wrap(complexFloats.flatMap(c => Array(c.real, c.imag))) + ), + complex64 + ) + case complexDoubleArray(complexDoubles) => + ( + new DoublePointer( + DoubleBuffer.wrap(complexDoubles.flatMap(c => Array(c.real, c.imag))) + ), + complex128 + ) case _ => throw new IllegalArgumentException(s"Unsupported sequence type") Tensor( torchNative diff --git a/core/src/main/scala/torch/Types.scala b/core/src/main/scala/torch/Types.scala new file mode 100644 index 00000000..99880a68 --- /dev/null +++ b/core/src/main/scala/torch/Types.scala @@ -0,0 +1,36 @@ +/* + * Copyright 2022 storch.dev + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package torch
+
+import shapeless3.typeable.{TypeCase, Typeable}
+import shapeless3.typeable.syntax.typeable.*
+import spire.math.Complex
+
+/* Typeable instance for Array[T]
+ * NOTE: It needs to iterate through the whole array to validate castability
+ */
+given iterableTypeable[T](using tt: Typeable[T]): Typeable[Array[T]] with
+  def castable(t: Any): Boolean =
+    t match
+      case (arr: Array[?]) =>
+        arr.forall(_.castable[T])
+      case _ => false
+  def describe = s"Array[${tt.describe}]"
+
+/* TypeCase helpers to perform pattern matching on `Complex` higher-kinded types */
+val complexDoubleArray = TypeCase[Array[Complex[Double]]]
+val complexFloatArray = TypeCase[Array[Complex[Float]]]
diff --git a/core/src/main/scala/torch/internal/NativeConverters.scala b/core/src/main/scala/torch/internal/NativeConverters.scala
index 0c7e6f42..8e631eaf 100644
--- a/core/src/main/scala/torch/internal/NativeConverters.scala
+++ b/core/src/main/scala/torch/internal/NativeConverters.scala
@@ -65,9 +65,8 @@ private[torch] object NativeConverters:
       case x: Long   => pytorch.Scalar(x)
       case x: Float  => pytorch.Scalar(x)
       case x: Double => pytorch.Scalar(x)
-      case x @ Complex(r: Float, i: Float)   => ???
-      case x @ Complex(r: Double, i: Double) => ???
-
+      case x @ Complex(r: Float, i: Float) => Tensor(Seq(x)).to(dtype = complex64).native.item()
+      case x @ Complex(r: Double, i: Double) => Tensor(Seq(x)).to(dtype = complex128).native.item()
 def tensorOptions(
     dtype: DType,
     layout: Layout,
diff --git a/core/src/test/scala/torch/TensorSuite.scala b/core/src/test/scala/torch/TensorSuite.scala
index cb369883..4f8e7609 100644
--- a/core/src/test/scala/torch/TensorSuite.scala
+++ b/core/src/test/scala/torch/TensorSuite.scala
@@ -211,13 +211,12 @@ class TensorSuite extends ScalaCheckSuite {
 
   // TODO clamp
 
-  // TODO Handle Complex Tensors
-  // testUnaryOp(
-  //   op = conjPhysical,
-  //   opName = "conjPhysical",
-  //   inputTensor = Tensor(Seq(180.0, -180.0, 360.0, -360.0, 90.0, -90.0)),
-  //   expectedTensor = Tensor(Seq(3.1416, -3.1416, 6.2832, -6.2832, 1.5708, -1.5708))
-  // )
+  testUnaryOp(
+    op = conjPhysical,
+    opName = "conjPhysical",
+    inputTensor = Tensor(Seq(Complex(-1.0, 1.0), Complex(-2.0, 2.0), Complex(3.0, -3.0))),
+    expectedTensor = Tensor(Seq(Complex(-1.0, -1.0), Complex(-2.0, -2.0), Complex(3.0, 3.0)))
+  )
 
   // TODO copysign
 
@@ -325,13 +324,19 @@ class TensorSuite extends ScalaCheckSuite {
 
   // TODO gradient
 
-  // TODO Handle Complex Tensors
-  // testUnaryOp(
-  //   op = imag,
-  //   opName = "imag",
-  //   inputTensor = Tensor(Seq((0.3100+0.3553j), (-0.5445-0.7896j), (-1.6492-0.0633j), (-0.0638-0.8119j))),
-  //   expectedTensor = Tensor(Seq(0.3553, -0.7896, -0.0633, -0.8119))
-  // )
+  testUnaryOp(
+    op = imag,
+    opName = "imag",
+    inputTensor = Tensor(
+      Seq(
+        Complex(0.31, 0.3553),
+        Complex(-0.5445, -0.7896),
+        Complex(-1.6492, -0.0633),
+        Complex(-0.0638, -0.8119)
+      )
+    ),
+    expectedTensor = Tensor(Seq(0.3553, -0.7896, -0.0633, -0.8119))
+  )
 
   // TODO ldexp
   // TODO lerp
@@ -433,13 +438,19 @@ class TensorSuite extends ScalaCheckSuite {
     expectedTensor = Tensor(Seq(180.0233, -180.0233, 359.9894, -359.9894, 89.9544, -89.9544))
   )
 
-  // TODO Handle Complex Tensors
-  // testUnaryOp(
-  //   op = real,
-  //   opName = "real",
-  //   inputTensor = Tensor(Seq((0.3100+0.3553j), (-0.5445-0.7896j), (-1.6492-0.0633j), (-0.0638-0.8119j))),
-  //   expectedTensor = Tensor(Seq(0.3100, -0.5445, -1.6492, -0.0638))
-  // )
+  testUnaryOp(
+    op = real,
+    opName = "real",
+    inputTensor = Tensor(
+      Seq(
+        Complex(0.31, 0.3553),
+        Complex(-0.5445, -0.7896),
+        Complex(-1.6492, -0.0633),
+        Complex(-0.0638, -0.8119)
+      )
+    ),
+    expectedTensor = Tensor(Seq(0.3100, -0.5445, -1.6492, -0.0638))
+  )
 
   testUnaryOp(
     op = reciprocal,
@@ -468,13 +479,15 @@ class TensorSuite extends ScalaCheckSuite {
     expectedTensor = Tensor(Seq(1.0, -1.0, 0.0, 1.0))
   )
 
-  // TODO Fix Complex Tensor creation
-  // testUnaryOp(
-  //   op = sgn,
-  //   opName = "sgn",
-  //   inputTensor = Tensor(Seq(Complex(3.0,4.0), Complex(7.0, -24.0), Complex(0.0, 0.0), Complex(1.0, 2.0))),
-  //   expectedTensor = Tensor(Seq(false, true, false, true, false))
-  // )
+  testUnaryOp(
+    op = sgn,
+    opName = "sgn",
+    inputTensor =
+      Tensor(Seq(Complex(3.0, 4.0), Complex(7.0, -24.0), Complex(0.0, 0.0), Complex(1.0, 2.0))),
+    expectedTensor = Tensor(
+      Seq(Complex(0.6, 0.8), Complex(0.28, -0.96), Complex(0.0, 0.0), Complex(0.4472, 0.8944))
+    )
+  )
 
   testUnaryOp(
     op = signbit,

From a81b76f22ab14fc9760aabea984d41aae6278f42 Mon Sep 17 00:00:00 2001
From: David Gomez-Urquiza
Date: Sat, 3 Jun 2023 23:42:13 -0600
Subject: [PATCH 5/5] Add more Pointwise ops, improve property test helpers,
 and type checking

Implementation

- Add more Pointwise ops; a few are still missing, but significant progress
  was made
- Add type evidence for operations that require (1) at most one bool, and
  (2) at least one float
- Add converter for ScalarOptional

Testing

- Property tests now also run complementary (negative) checks, to verify that
  our type annotations are not too strict. This uncovered some bugs, but also
  a few still-unsolved type-level challenges (see the new type evidence, such
  as AtLeastOneFloat)
- Disable the complex32 dtype generator, as many operations do not support it
  yet, apparently because the dtype is still experimental
- Add test helper for binary ops
---
 core/src/main/scala/torch/DType.scala         |  16 +-
 core/src/main/scala/torch/Types.scala         |  11 +
 .../torch/internal/NativeConverters.scala     |  13 +-
 .../main/scala/torch/special/package.scala    |  34 +-
 core/src/main/scala/torch/torch.scala         | 284 ++++++----
 core/src/test/scala/torch/Generators.scala    |   4 +-
 .../test/scala/torch/TensorCheckSuite.scala   | 193 +++++++
 core/src/test/scala/torch/TensorSuite.scala   | 533 ++++++++++++++----
 8 files changed, 852 insertions(+), 236 deletions(-)
 create mode 100644 core/src/test/scala/torch/TensorCheckSuite.scala

diff --git a/core/src/main/scala/torch/DType.scala b/core/src/main/scala/torch/DType.scala
index 22bad59e..f32e3773 100644
--- a/core/src/main/scala/torch/DType.scala
+++ b/core/src/main/scala/torch/DType.scala
@@ -32,7 +32,7 @@ import java.nio.{
 }
 import scala.annotation.{targetName, unused}
 import scala.reflect.ClassTag
-import spire.math.*
+import spire.math.{Complex, UByte}
 
 import scala.compiletime.{erasedValue, summonFrom}
 
@@ -209,6 +209,7 @@ private object Derive:
   val derive: Derive = Derive()
 export Derive.derive
 
+/** DType combinations */
 type FloatNN = Float16 | Float32 | Float64 | BFloat16
 
 type IntNN = Int8 | UInt8 | Int16 | Int32 | Int64
@@ -223,9 +224,16 @@ type RealNN = NumericRealNN | Bool
 
 type NumericNN = NumericRealNN | ComplexNN
 
-type Real = Boolean | Byte | UByte | Short | Int | Long | Float | Double
+/** Scala type combinations */
+type NumericReal = Byte | UByte | Short | Int | Long | Float | Double
 
-type ScalaType = Real | Complex[Float] | Complex[Double]
+type Real = NumericReal | Boolean
+
+type ComplexScala = Complex[Float] | Complex[Double]
+
+type Numeric = NumericReal | ComplexScala
+
+type ScalaType = Real | ComplexScala
 
 type DTypeToScala[T <: DType] <: ScalaType = T match
   case UInt8 => UByte
@@ -298,8
+306,6 @@ type DTypeOrDeriveFromTensor[D1 <: DType, U <: DType | Derive] <: DType = U matc case Derive => D1 case U => TensorType[U] -type TensorOrReal[D <: DType] = Tensor[D] | Real - type PromotedDType[A <: DType, B <: DType] <: Float32 | Int32 | Int64 = (A, B) match case (Float64, B) => Float32 case (A, Float64) => Float32 diff --git a/core/src/main/scala/torch/Types.scala b/core/src/main/scala/torch/Types.scala index 99880a68..c32a9823 100644 --- a/core/src/main/scala/torch/Types.scala +++ b/core/src/main/scala/torch/Types.scala @@ -18,6 +18,7 @@ package torch import shapeless3.typeable.{TypeCase, Typeable} import shapeless3.typeable.syntax.typeable.* +import scala.util.NotGiven import spire.math.Complex /* Typeable instance for Array[T] @@ -34,3 +35,13 @@ given iterableTypeable[T](using tt: Typeable[T]): Typeable[Array[T]] with /* TypeCase helpers to perform pattern matching on `Complex` higher kinded types */ val complexDoubleArray = TypeCase[Array[Complex[Double]]] val complexFloatArray = TypeCase[Array[Complex[Float]]] + +/* Type helper to describe inputs that accept Tensor or Real scalars */ +type TensorOrReal[D <: RealNN] = Tensor[D] | Real + +/* Evidence used in operations where Bool is accepted, but only on one of the two inputs, not both + */ +type OnlyOneBool[A <: DType, B <: DType] = NotGiven[A =:= Bool & B =:= Bool] + +/* Evidence used in operations where at least one Float is required */ +type AtLeastOneFloat[A <: DType, B <: DType] = A <:< FloatNN | B <:< FloatNN diff --git a/core/src/main/scala/torch/internal/NativeConverters.scala b/core/src/main/scala/torch/internal/NativeConverters.scala index 8e631eaf..c0c62e4a 100644 --- a/core/src/main/scala/torch/internal/NativeConverters.scala +++ b/core/src/main/scala/torch/internal/NativeConverters.scala @@ -41,10 +41,19 @@ private[torch] object NativeConverters: case i: Option[T] => i.map(f(_)).orNull case i: T => f(i) - def toOptional(l: Long | Option[Long]): LongOptional = toOptional(l, pytorch.LongOptional(_)) - def toOptional(l: Double | Option[Double]): DoubleOptional = + def toOptional(l: Long | Option[Long]): pytorch.LongOptional = + toOptional(l, pytorch.LongOptional(_)) + def toOptional(l: Double | Option[Double]): pytorch.DoubleOptional = toOptional(l, pytorch.DoubleOptional(_)) + def toOptional(l: Real | Option[Real]): pytorch.ScalarOptional = + toOptional( + l, + (r: Real) => + val scalar = toScalar(r) + pytorch.ScalarOptional(scalar) + ) + def toOptional[D <: DType](t: Tensor[D] | Option[Tensor[D]]): TensorOptional = toOptional(t, t => pytorch.TensorOptional(t.native)) diff --git a/core/src/main/scala/torch/special/package.scala b/core/src/main/scala/torch/special/package.scala index 588ebe0e..cc35b6fa 100644 --- a/core/src/main/scala/torch/special/package.scala +++ b/core/src/main/scala/torch/special/package.scala @@ -40,7 +40,7 @@ package object special: Tensor(torchNative.erfinv(input.native)) /** Computes the base two exponential function of `input`. */ - def exp2[D <: RealNN](input: Tensor[D]): Tensor[FloatPromoted[D]] = + def exp2[D <: DType](input: Tensor[D]): Tensor[FloatPromoted[D]] = Tensor(torchNative.exp2(input.native)) /** Computes the exponential of the elements minus 1 of `input`. 
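+   * A usage sketch (illustrative, editor-added; not part of the original patch):
+   * {{{
+   * // expm1(x) = exp(x) - 1, computed accurately even for small x
+   * torch.special.expm1(Tensor(Seq(0.0, math.log(2.0)))) // ~ Tensor(Seq(0.0, 1.0))
+   * }}}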
 */
@@ -55,36 +55,38 @@ package object special:
 
   /** Computes the regularized lower incomplete gamma function */
   // NOTE it is named `gammainc` in pytorch torch.special
-  def igamma[D <: DType, D2 <: DType](
+  // TODO Change `D2 <: RealNN` once we fix property testing compilation
+  def igamma[D <: RealNN, D2 <: FloatNN](
       input: Tensor[D],
       other: Tensor[D2]
-  ): Tensor[FloatPromoted[Promoted[D, D2]]] =
+  )(using AtLeastOneFloat[D, D2]): Tensor[FloatPromoted[Promoted[D, D2]]] =
     Tensor(torchNative.igamma(input.native, other.native))
 
   /** Computes the regularized upper incomplete gamma function */
   // NOTE it is named `gamaincc` in pytorch torch.special
-  def igammac[D <: DType, D2 <: DType](
+  // TODO Change `D2 <: RealNN` once we fix property testing compilation
+  def igammac[D <: RealNN, D2 <: FloatNN](
      input: Tensor[D],
      other: Tensor[D2]
-  ): Tensor[FloatPromoted[Promoted[D, D2]]] =
+  )(using AtLeastOneFloat[D, D2]): Tensor[FloatPromoted[Promoted[D, D2]]] =
     Tensor(torchNative.igammac(input.native, other.native))
 
   /** Returns a new tensor with the logit of the elements of `input`. `input` is clamped to [eps, 1
    * \- eps] when eps is not None. When eps is None and input < 0 or input > 1, the function will
    * yield NaN.
    */
-  def logit[D <: DType](input: Tensor[D], eps: Option[Double]): Tensor[FloatPromoted[D]] =
+  def logit[D <: RealNN](input: Tensor[D], eps: Option[Double]): Tensor[FloatPromoted[D]] =
     Tensor(torchNative.logit(input.native, toOptional(eps)))
 
   /** Computes the multivariate log-gamma function with dimension p element-wise */
   // NOTE it is named `multigammaln` in pytorch torch.special
-  def mvlgamma[D <: DType](input: Tensor[D], p: Int): Tensor[FloatPromoted[D]] =
+  def mvlgamma[D <: NumericRealNN](input: Tensor[D], p: Int): Tensor[FloatPromoted[D]] =
     Tensor(torchNative.mvlgamma(input.native, p))
 
   /** Computes the nth derivative of the digamma function on `input`. n≥0 is called the order of the
    * polygamma function.
    */
-  def polygamma[D <: DType](n: Int, input: Tensor[D]): Tensor[FloatPromoted[D]] =
+  def polygamma[D <: RealNN](n: Int, input: Tensor[D]): Tensor[FloatPromoted[D]] =
     Tensor(torchNative.polygamma(n, input.native))
 
   /** Computes the expit (also known as the logistic sigmoid function) of the elements of `input`.
@@ -94,5 +96,21 @@ package object special:
     Tensor(torchNative.sigmoid(input.native))
 
   /** Returns a new tensor with the normalized sinc of the elements of `input`. */
-  def sinc[D <: RealNN](input: Tensor[D]): Tensor[FloatPromoted[D]] =
+  def sinc[D <: DType](input: Tensor[D]): Tensor[FloatPromoted[D]] =
     Tensor(torchNative.sinc(input.native))
+
+  /** Computes `input * log(other)`. Elementwise, the result is `NaN` where `other` is `NaN`, zero
+   * where `input` is zero, and `input * log(other)` everywhere else.
+   */
+  // TODO handle Scalar `input`
+  def xlogy[D <: RealNN, D2 <: RealNN](
+      input: Tensor[D],
+      other: TensorOrReal[D2]
+  ): Tensor[FloatPromoted[D]] =
+    Tensor(
+      other match
+        case other: Tensor[D2] =>
+          torchNative.xlogy(input.native, other.native)
+        case other: Real =>
+          torchNative.xlogy(input.native, toScalar(other))
+    )
diff --git a/core/src/main/scala/torch/torch.scala b/core/src/main/scala/torch/torch.scala
index 5ca7f293..7ccd79b0 100644
--- a/core/src/main/scala/torch/torch.scala
+++ b/core/src/main/scala/torch/torch.scala
@@ -538,13 +538,16 @@ def stack[D <: DType](tensors: Seq[Tensor[D]], dim: Int = 0): Tensor[D] = Tensor
 
 // Pointwise Ops
 
 /** Computes the absolute value of each element in `input`.
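+ * A usage sketch (editor-added; values mirror the test suite in this series):
+ * {{{
+ * torch.abs(Tensor(Seq(-1, -2, 3))) // Tensor(Seq(1, 2, 3))
+ * }}}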
*/ -def abs[D <: DType](input: Tensor[D]) = Tensor(torchNative.abs(input.native)) +def abs[D <: NumericNN](input: Tensor[D]): Tensor[D] = + Tensor(torchNative.abs(input.native)) /** Computes the inverse cosine of each element in `input`. */ -def acos[D <: DType](input: Tensor[D]) = Tensor(torchNative.acos(input.native)) +def acos[D <: DType](input: Tensor[D]): Tensor[D] = + Tensor(torchNative.acos(input.native)) /** Returns a new tensor with the inverse hyperbolic cosine of the elements of `input` . */ -def acosh[D <: DType](input: Tensor[D]) = Tensor(torchNative.acosh(input.native)) +def acosh[D <: DType](input: Tensor[D]): Tensor[D] = + Tensor(torchNative.acosh(input.native)) /** Adds `other` to `input`. */ def add[D <: DType, D2 <: DType](input: Tensor[D], other: Tensor[D2]): Tensor[Promoted[D, D2]] = @@ -580,23 +583,23 @@ def addcmul[D <: DType, D2 <: DType, D3 <: DType]( Tensor(torchNative.addcmul(input.native, tensor1.native, tensor2.native, toScalar(value))) /** Computes the element-wise angle (in radians) of the given `input` tensor. */ -def angle[D <: DType](input: Tensor[D]): Tensor[ComplexToReal[D]] = +def angle[D <: DType](input: Tensor[D]): Tensor[FloatPromoted[ComplexToReal[D]]] = Tensor(torchNative.angle(input.native)) /** Returns a new tensor with the arcsine of the elements of `input`. */ -def asin[D <: RealNN](input: Tensor[D]): Tensor[FloatPromoted[D]] = +def asin[D <: DType](input: Tensor[D]): Tensor[FloatPromoted[D]] = Tensor(torchNative.asin(input.native)) /** Returns a new tensor with the inverse hyperbolic sine of the elements of `input`. */ -def asinh[D <: RealNN](input: Tensor[D]): Tensor[FloatPromoted[D]] = +def asinh[D <: DType](input: Tensor[D]): Tensor[FloatPromoted[D]] = Tensor(torchNative.asinh(input.native)) /** Returns a new tensor with the arctangent of the elements of `input`. */ -def atan[D <: RealNN](input: Tensor[D]): Tensor[FloatPromoted[D]] = +def atan[D <: DType](input: Tensor[D]): Tensor[FloatPromoted[D]] = Tensor(torchNative.atan(input.native)) /** Returns a new tensor with the inverse hyperbolic tangent of the elements of `input`. */ -def atanh[D <: RealNN](input: Tensor[D]): Tensor[FloatPromoted[D]] = +def atanh[D <: DType](input: Tensor[D]): Tensor[FloatPromoted[D]] = Tensor(torchNative.atanh(input.native)) /** Element-wise arctangent of (input / other) with consideration of the quadrant. Returns a new @@ -604,7 +607,7 @@ def atanh[D <: RealNN](input: Tensor[D]): Tensor[FloatPromoted[D]] = * that other, the second parameter, is the x-coordinate, while input, the first parameter, is the * y-coordinate.) */ -def atan2[D <: DType, D2 <: DType]( +def atan2[D <: RealNN, D2 <: RealNN]( input: Tensor[D], other: Tensor[D2] ): Tensor[FloatPromoted[Promoted[D, D2]]] = @@ -618,25 +621,41 @@ def bitwiseNot[D <: BitwiseNN](input: Tensor[D]): Tensor[D] = /** Computes the bitwise AND of `input` and `other`. For bool tensors, it computes the logical AND. */ -def bitwiseAnd[D <: BitwiseNN](input: Tensor[D], other: Tensor[D]): Tensor[D] = +def bitwiseAnd[D <: BitwiseNN, D2 <: BitwiseNN]( + input: Tensor[D], + other: Tensor[D2] +): Tensor[Promoted[D, D2]] = Tensor(torchNative.bitwise_and(input.native, other.native)) /** Computes the bitwise OR of `input` and `other`. For bool tensors, it computes the logical OR. 
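+ * A usage sketch (editor-added; values mirror the test suite in this series):
+ * {{{
+ * // two's-complement bitwise OR; -1 is all ones, so -1 | x == -1
+ * torch.bitwiseOr(Tensor(Seq(-1, -2, 3)), Tensor(Seq(1, 0, 3))) // Tensor(Seq(-1, -2, 3))
+ * }}}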
 */
-def bitwiseOr[D <: BitwiseNN](input: Tensor[D], other: Tensor[D]): Tensor[D] =
+def bitwiseOr[D <: BitwiseNN, D2 <: BitwiseNN](
+    input: Tensor[D],
+    other: Tensor[D2]
+): Tensor[Promoted[D, D2]] =
   Tensor(torchNative.bitwise_or(input.native, other.native))
 
 /** Computes the bitwise XOR of `input` and `other`. For bool tensors, it computes the logical XOR.
  */
-def bitwiseXor[D <: BitwiseNN](input: Tensor[D], other: Tensor[D]): Tensor[D] =
+def bitwiseXor[D <: BitwiseNN, D2 <: BitwiseNN](
+    input: Tensor[D],
+    other: Tensor[D2]
+): Tensor[Promoted[D, D2]] =
   Tensor(torchNative.bitwise_xor(input.native, other.native))
 
 /** Computes the left arithmetic shift of `input` by `other` bits. */
+
-def bitwiseLeftShift[D <: IntNN](input: Tensor[D], other: Tensor[D]): Tensor[D] =
+def bitwiseLeftShift[D <: BitwiseNN, D2 <: BitwiseNN](
+    input: Tensor[D],
+    other: Tensor[D2]
+)(using OnlyOneBool[D, D2]): Tensor[Promoted[D, D2]] =
   Tensor(torchNative.bitwise_left_shift(input.native, other.native))
 
 /** Computes the right arithmetic shift of `input` by `other` bits. */
-def bitwiseRightShift[D <: IntNN](input: Tensor[D], other: Tensor[D]): Tensor[D] =
+def bitwiseRightShift[D <: BitwiseNN, D2 <: BitwiseNN](
+    input: Tensor[D],
+    other: Tensor[D2]
+)(using OnlyOneBool[D, D2]): Tensor[Promoted[D, D2]] =
   Tensor(torchNative.bitwise_right_shift(input.native, other.native))
 
 /** Returns a new tensor with the ceil of the elements of `input`, the smallest integer greater than
@@ -649,32 +668,41 @@ def ceil[D <: NumericRealNN](input: Tensor[D]): Tensor[D] =
  * and max, respectively, this returns: `min(max(input, min_value), max_value)`. If min is None,
  * there is no lower bound, and if max is None, there is no upper bound.
  */
-def clamp[D <: NumericNN](
+// TODO Support Tensor for min and max
+def clamp[D <: RealNN](
     input: Tensor[D],
-    min: Option[Tensor[D]],
-    max: Option[Tensor[D]]
+    min: Option[Real],
+    max: Option[Real]
 ): Tensor[D] =
   Tensor(torchNative.clamp(input.native, toOptional(min), toOptional(max)))
 
 /** Computes the element-wise conjugate of the given input tensor. If input has a non-complex dtype,
  * this function just returns input.
  */
-def conjPhysical[D <: NumericNN](input: Tensor[D]): Tensor[D] =
+def conjPhysical[D <: DType](input: Tensor[D]): Tensor[D] =
   Tensor(torchNative.conj_physical(input.native))
 
 /** Create a new floating-point tensor with the magnitude of input and the sign of other,
  * elementwise.
  */
-// TODO
-// def copysign[D <: DType](input: Tensor[D]): Tensor[D] =
-//   Tensor(torchNative.copysign(input.native))
+def copysign[D <: RealNN, D2 <: RealNN](
+    input: Tensor[D],
+    other: TensorOrReal[D2]
+): Tensor[FloatPromoted[D]] =
+  Tensor(
+    other match
+      case other: Tensor[D2] =>
+        torchNative.copysign(input.native, other.native)
+      case other: Real =>
+        torchNative.copysign(input.native, toScalar(other))
+  )
 
 /** Returns a new tensor with the cosine of the elements of `input`. */
-def cos[D <: RealNN](input: Tensor[D]): Tensor[FloatPromoted[D]] =
+def cos[D <: DType](input: Tensor[D]): Tensor[FloatPromoted[D]] =
   Tensor(torchNative.cos(input.native))
 
 /** Returns a new tensor with the hyperbolic cosine of the elements of `input`.
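+ * A usage sketch (editor-added, hand-checked value):
+ * {{{
+ * torch.cosh(Tensor(Seq(0.0))) // Tensor(Seq(1.0)), since cosh(0) == (e^0 + e^-0) / 2 == 1
+ * }}}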
*/ -def cosh[D <: RealNN](input: Tensor[D]): Tensor[FloatPromoted[D]] = +def cosh[D <: DType](input: Tensor[D]): Tensor[FloatPromoted[D]] = Tensor(torchNative.cosh(input.native)) /** Returns a new tensor with each of the elements of `input` converted from angles in degrees to @@ -684,16 +712,27 @@ def deg2rad[D <: RealNN](input: Tensor[D]): Tensor[FloatPromoted[D]] = Tensor(torchNative.deg2rad(input.native)) /** Divides each element of the input `input` by the corresponding element of `other`. */ -def div[D <: DType, D2 <: DType](input: Tensor[D], other: Tensor[D2]): Tensor[D] = +// TODO handle roundingMode + +def div[D <: DType, D2 <: DType]( + input: Tensor[D], + other: Tensor[D2] +): Tensor[FloatPromoted[Promoted[D, D2]]] = Tensor(torchNative.div(input.native, other.native)) +def div[D <: DType, S <: ScalaType]( + input: Tensor[D], + other: S +): Tensor[FloatPromoted[Promoted[D, ScalaToDType[S]]]] = + Tensor(torchNative.div(input.native, toScalar(other))) + export torch.special.digamma export torch.special.erf export torch.special.erfc export torch.special.erfinv /** Returns a new tensor with the exponential of the elements of the input tensor `input`. */ -def exp[D <: RealNN](input: Tensor[D]): Tensor[D] = +def exp[D <: DType](input: Tensor[D]): Tensor[D] = Tensor(torchNative.exp(input.native)) export torch.special.exp2 @@ -702,7 +741,6 @@ export torch.special.expm1 /** Returns a new tensor with the data in `input` fake quantized per channel using `scale`, * `zero_point`, `quant_min` and `quant_max`, across the channel specified by `axis`. */ -// TODO Fix pytorch docs to add `axis` input def fakeQuantizePerChannelAffine( input: Tensor[Float32], scale: Tensor[Float32], @@ -753,7 +791,11 @@ def fakeQuantizePerTensorAffine( torchNative.fake_quantize_per_tensor_affine(input.native, scale, zeroPoint, quantMin, quantMax) ) -// TODO torch.fix // Alias for torch.trunc +/** Returns a new tensor with the truncated integer values of the elements of `input`. Alias for + * torch.trunc + */ +def fix[D <: NumericRealNN](input: Tensor[D]): Tensor[D] = + Tensor(torchNative.fix(input.native)) /** Raises `input` to the power of `exponent`, elementwise, in double precision. If neither input is * complex returns a `torch.float64` tensor, and if one or more inputs is complex returns a @@ -784,30 +826,33 @@ def floor[D <: NumericRealNN](input: Tensor[D]): Tensor[D] = Tensor(torchNative.floor(input.native)) /** Computes `input` divided by `other`, elementwise, and floors the result. */ -def floorDivide[D <: DType, D2 <: DType]( +def floorDivide[D <: RealNN, D2 <: RealNN]( input: Tensor[D], - other: TensorOrReal[D2] -): Tensor[Promoted[D, D2]] = - Tensor( - (input, other) match - case (input: Tensor[D], other: Tensor[D2]) => - torchNative.floor_divide(input.native, other.native) - case (input: Tensor[D], other: Real) => - torchNative.floor_divide(input.native, toScalar(other)) - ) + other: Tensor[D2] +)(using OnlyOneBool[D, D2]): Tensor[Promoted[D, D2]] = + Tensor(torchNative.floor_divide(input.native, other.native)) + +def floorDivide[D <: RealNN, R <: Real]( + input: Tensor[D], + other: R +)(using OnlyOneBool[D, ScalaToDType[R]]): Tensor[Promoted[D, ScalaToDType[R]]] = + Tensor(torchNative.floor_divide(input.native, toScalar(other))) /** Applies C++’s `std::fmod` entrywise. The result has the same sign as the dividend `input` and * its absolute value is less than that of `other`. 
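+ * A usage sketch (editor-added; values mirror the test suite in this series):
+ * {{{
+ * // the sign of each result follows the dividend, not the divisor
+ * torch.fmod(Tensor(Seq(-3.0, -2.0, -1.0, 1.0, 2.0, 3.0)), Tensor(Seq(2.0, 2.0, 2.0, 2.0, 2.0, 2.0)))
+ * // Tensor(Seq(-1.0, -0.0, -1.0, 1.0, 0.0, 1.0))
+ * }}}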
*/ // NOTE: When the divisor is zero, returns NaN for floating point dtypes on both CPU and GPU; raises RuntimeError for integer division by zero on CPU; Integer division by zero on GPU may return any value. -def fmod[D <: RealNN](input: Tensor[D], other: TensorOrReal[D]): Tensor[D] = - Tensor( - other match - case (other: Tensor[D]) => - torchNative.fmod(input.native, other.native) - case (other: Real) => - torchNative.fmod(input.native, toScalar(other)) - ) +def fmod[D <: RealNN, D2 <: RealNN]( + input: Tensor[D], + other: Tensor[D2] +)(using OnlyOneBool[D, D2]): Tensor[Promoted[D, D2]] = + Tensor(torchNative.fmod(input.native, other.native)) + +def fmod[D <: RealNN, S <: ScalaType]( + input: Tensor[D], + other: S +)(using OnlyOneBool[D, ScalaToDType[S]]): Tensor[Promoted[D, ScalaToDType[S]]] = + Tensor(torchNative.fmod(input.native, toScalar(other))) /** Computes the fractional portion of each element in `input`. */ def frac[D <: FloatNN](input: Tensor[D]): Tensor[D] = @@ -820,10 +865,12 @@ def frexp[D <: FloatNN](input: Tensor[D]): (Tensor[FloatPromoted[D]], Tensor[Int val nativeTuple = torchNative.frexp(input.native) (Tensor(nativeTuple.get0), new Int32Tensor(nativeTuple.get1)) -// TODO implement -/** */ -// def gradient[D <: DType](input: Tensor[D]): Tensor[D] = -// Tensor(torchNative.???) +/** Estimates the gradient of a function g:Rn → R in one or more dimensions using the second-order + * accurate central differences method. + */ +// TODO handle other spacing and dim invariants +// def gradient[D <: DType](input: Tensor[D], spacing: Float, dim: Option[Long], edgeOrder: Long = 1): Tensor[D] = +// Tensor(torchNative.gradient(input.native, toScalar(spacing), toOptional(dim), edgeOrder)) /** Returns a new tensor containing imaginary values of the `input` tensor. The returned tensor and * `input` share the same underlying storage. @@ -856,19 +903,19 @@ def lgamma[D <: RealNN](input: Tensor[D]): Tensor[D] = Tensor(torchNative.lgamma(input.native)) /** Returns a new tensor with the natural logarithm of the elements of `input`. */ -def log[D <: RealNN](input: Tensor[D]): Tensor[FloatPromoted[D]] = +def log[D <: DType](input: Tensor[D]): Tensor[FloatPromoted[D]] = Tensor(torchNative.log(input.native)) /** Returns a new tensor with the logarithm to the base 10 of the elements of `input`. */ -def log10[D <: RealNN](input: Tensor[D]): Tensor[FloatPromoted[D]] = +def log10[D <: DType](input: Tensor[D]): Tensor[FloatPromoted[D]] = Tensor(torchNative.log10(input.native)) /** Returns a new tensor with the natural logarithm of (1 + input). */ -def log1p[D <: RealNN](input: Tensor[D]): Tensor[FloatPromoted[D]] = +def log1p[D <: DType](input: Tensor[D]): Tensor[FloatPromoted[D]] = Tensor(torchNative.log1p(input.native)) /** Returns a new tensor with the logarithm to the base 2 of the elements of `input`. */ -def log2[D <: RealNN](input: Tensor[D]): Tensor[FloatPromoted[D]] = +def log2[D <: DType](input: Tensor[D]): Tensor[FloatPromoted[D]] = Tensor(torchNative.log2(input.native)) /** Logarithm of the sum of exponentiations of the inputs. Calculates pointwise log `log(e**x + @@ -878,7 +925,7 @@ def log2[D <: RealNN](input: Tensor[D]): Tensor[FloatPromoted[D]] = * such a fashion. This op should be disambiguated with `torch.logsumexp()` which performs a * reduction on a single tensor. 
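+ * A usage sketch (editor-added; values mirror the test suite in this series):
+ * {{{
+ * // e^-100 etc. underflow to zero, so each result is dominated by the larger operand
+ * torch.logaddexp(Tensor(Seq(-100.0, -200.0, -300.0)), Tensor(Seq(-1.0, -2.0, -3.0)))
+ * // ~ Tensor(Seq(-1.0, -2.0, -3.0))
+ * }}}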
*/ -def logaddexp[D <: DType, D2 <: DType]( +def logaddexp[D <: RealNN, D2 <: RealNN]( input: Tensor[D], other: Tensor[D2] ): Tensor[Promoted[D, D2]] = @@ -887,7 +934,7 @@ def logaddexp[D <: DType, D2 <: DType]( /** Logarithm of the sum of exponentiations of the inputs in base-2. Calculates pointwise `log2(2**x * + 2**y)`. See torch.logaddexp() for more details. */ -def logaddexp2[D <: DType, D2 <: DType]( +def logaddexp2[D <: RealNN, D2 <: RealNN]( input: Tensor[D], other: Tensor[D2] ): Tensor[Promoted[D, D2]] = @@ -899,11 +946,11 @@ def logaddexp2[D <: DType, D2 <: DType]( def logicalAnd[D <: DType, D2 <: DType](input: Tensor[D], other: Tensor[D2]): Tensor[Bool] = Tensor(torchNative.logical_and(input.native, other.native)) -/** Computes the element-wise logical NOT of the given input tensor. TODO If not specified, the - * output tensor will have the bool dtype. If the input tensor is not a bool tensor, zeros are - * treated as False and non-zeros are treated as True. +/** Computes the element-wise logical NOT of the given input tensor. If the input tensor is not a + * bool tensor, zeros are treated as False and non-zeros are treated as True. TODO If not + * specified, the output tensor will have the bool dtype. */ -def logicalNot[D <: RealNN](input: Tensor[D]): Tensor[Bool] = +def logicalNot[D <: DType](input: Tensor[D]): Tensor[Bool] = Tensor(torchNative.logical_not(input.native)) /** Computes the element-wise logical OR of the given input tensors. Zeros are treated as False and @@ -916,15 +963,16 @@ def logicalOr[D <: DType, D2 <: DType](input: Tensor[D], other: Tensor[D2]): Ten * nonzeros are treated as True. */ def logicalXor[D <: DType, D2 <: DType](input: Tensor[D], other: Tensor[D2]): Tensor[Bool] = - Tensor(torchNative.logical_or(input.native, other.native)) + Tensor(torchNative.logical_xor(input.native, other.native)) export torch.special.logit /** Given the legs of a right triangle, return its hypotenuse. */ -def hypot[D <: DType, D2 <: DType]( +// TODO Change `D2 <: RealNN` once we fix property testing compilation +def hypot[D <: RealNN, D2 <: FloatNN]( input: Tensor[D], other: Tensor[D2] -): Tensor[FloatPromoted[Promoted[D, D2]]] = +)(using AtLeastOneFloat[D, D2]): Tensor[FloatPromoted[Promoted[D, D2]]] = Tensor(torchNative.hypot(input.native, other.native)) export torch.special.i0 @@ -942,11 +990,11 @@ export torch.special.mvlgamma * positive infinity is replaced with the greatest finite value representable by input’s dtype, and * negative infinity is replaced with the least finite value representable by input’s dtype. */ -def nanToNum[D <: FloatNN]( +def nanToNum[D <: RealNN]( input: Tensor[D], nan: Option[Double] = None, - posinf: Option[Double], - neginf: Option[Double] + posinf: Option[Double] = None, + neginf: Option[Double] = None ): Tensor[D] = Tensor( torchNative.nan_to_num(input.native, toOptional(nan), toOptional(posinf), toOptional(neginf)) @@ -957,10 +1005,11 @@ def neg[D <: NumericNN](input: Tensor[D]): Tensor[D] = Tensor(torchNative.neg(input.native)) /** Return the next floating-point value after `input` towards `other`, elementwise. 
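+ * A usage sketch (editor-added; each result differs from `input` by only one unit in the last
+ * place, so it prints like the input):
+ * {{{
+ * torch.nextafter(Tensor(Seq(1.0, 2.0)), Tensor(Seq(2.0, 1.0))) // ~ Tensor(Seq(1.0, 2.0))
+ * }}}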
*/ -def nextafter[D <: DType, D2 <: DType]( +// TODO Change `D2 <: RealNN` once we fix property testing compilation +def nextafter[D <: RealNN, D2 <: FloatNN]( input: Tensor[D], other: Tensor[D2] -): Tensor[FloatPromoted[Promoted[D, D2]]] = +)(using AtLeastOneFloat[D, D2]): Tensor[FloatPromoted[Promoted[D, D2]]] = Tensor(torchNative.nextafter(input.native, other.native)) export torch.special.polygamma @@ -973,19 +1022,25 @@ def positive[D <: NumericNN](input: Tensor[D]): Tensor[D] = * `exponent` can be either a single float number or a Tensor with the same number of elements as * input. */ -// TODO handle Scalar `input` def pow[D <: DType, D2 <: DType]( input: Tensor[D], - exponent: TensorOrReal[D2] -): Tensor[FloatPromoted[D]] = - Tensor( - (input, exponent) match - case (input: Tensor[D], exponent: Tensor[D2]) => - torchNative.pow(input.native, exponent.native) - case (input: Tensor[D], exponent: Real) => - torchNative.pow(input.native, toScalar(exponent)) - ) + exponent: Tensor[D2] +)(using OnlyOneBool[D, D2]): Tensor[Promoted[D, D2]] = + Tensor(torchNative.pow(input.native, exponent.native)) +def pow[D <: DType, S <: ScalaType]( + input: Tensor[D], + exponent: S +)(using OnlyOneBool[D, ScalaToDType[S]]): Tensor[Promoted[D, ScalaToDType[S]]] = + Tensor(torchNative.pow(input.native, toScalar(exponent))) + +def pow[S <: ScalaType, D <: DType]( + input: S, + exponent: Tensor[D] +)(using OnlyOneBool[ScalaToDType[S], D]): Tensor[Promoted[ScalaToDType[S], D]] = + Tensor(torchNative.pow(toScalar(input), exponent.native)) + +// TODO Implement creation of QInts // TODO quantized_batch_norm // TODO quantized_max_pool1d // TODO quantized_max_pool2d @@ -1003,34 +1058,39 @@ def real[D <: DType](input: Tensor[D]): Tensor[ComplexToReal[D]] = Tensor(torchNative.real(input.native)) /** Returns a new tensor with the reciprocal of the elements of `input` */ -def reciprocal[D <: RealNN](input: Tensor[D]): Tensor[FloatPromoted[D]] = +def reciprocal[D <: DType](input: Tensor[D]): Tensor[FloatPromoted[D]] = Tensor(torchNative.reciprocal(input.native)) /** Computes Python’s modulus operation entrywise. The result has the same sign as the divisor * `other` and its absolute value is less than that of `other`. */ -// TODO handle Scalar `input` -def remainder[D <: DType, D2 <: DType]( +def remainder[D <: RealNN, D2 <: RealNN]( input: Tensor[D], - other: TensorOrReal[D2] -): Tensor[FloatPromoted[D]] = - Tensor( - (input, other) match - case (input: Tensor[D], other: Tensor[D2]) => - torchNative.remainder(input.native, other.native) - case (input: Tensor[D], other: Real) => - torchNative.remainder(input.native, toScalar(other)) - ) + other: Tensor[D2] +): Tensor[Promoted[D, D2]] = + Tensor(torchNative.remainder(input.native, other.native)) + +def remainder[D <: DType, R <: Real]( + input: Tensor[D], + other: R +): Tensor[Promoted[D, ScalaToDType[R]]] = + Tensor(torchNative.remainder(input.native, toScalar(other))) + +def remainder[D <: DType, R <: Real]( + input: R, + other: Tensor[D] +): Tensor[Promoted[ScalaToDType[R], D]] = + Tensor(torchNative.remainder(toScalar(input), other.native)) /** Rounds elements of `input` to the nearest integer. If decimals is negative, it specifies the * number of positions to the left of the decimal point. 
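+ * A usage sketch (editor-added; values mirror the test suite in this series):
+ * {{{
+ * torch.round(Tensor(Seq(4.7, -2.3, 9.1, -7.7)))    // Tensor(Seq(5.0, -2.0, 9.0, -8.0))
+ * torch.round(Tensor(Seq(0.1234567)), decimals = 3) // ~ Tensor(Seq(0.123))
+ * }}}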
*/ -def round[D <: NumericNN](input: Tensor[D], decimals: Long = 0): Tensor[D] = +def round[D <: FloatNN](input: Tensor[D], decimals: Long = 0): Tensor[D] = Tensor(torchNative.round(input.native, decimals)) /** Returns a new tensor with the reciprocal of the square-root of each of the elements of `input`. */ -def rsqrt[D <: RealNN](input: Tensor[D]): Tensor[FloatPromoted[D]] = +def rsqrt[D <: DType](input: Tensor[D]): Tensor[FloatPromoted[D]] = Tensor(torchNative.rsqrt(input.native)) export torch.special.sigmoid @@ -1052,60 +1112,72 @@ def signbit[D <: RealNN](input: Tensor[D]): Tensor[Bool] = Tensor(torchNative.signbit(input.native)) /** Returns a new tensor with the sine of the elements of `input`. */ -def sin[D <: RealNN](input: Tensor[D]): Tensor[FloatPromoted[D]] = +def sin[D <: DType](input: Tensor[D]): Tensor[FloatPromoted[D]] = Tensor(torchNative.sin(input.native)) export torch.special.sinc /** Returns a new tensor with the hyperbolic sine of the elements of `input`. */ -def sinh[D <: RealNN](input: Tensor[D]): Tensor[FloatPromoted[D]] = +def sinh[D <: DType](input: Tensor[D]): Tensor[FloatPromoted[D]] = Tensor(torchNative.sinh(input.native)) -// TODO softmax - export torch.nn.functional.softmax /** Returns a new tensor with the square-root of the elements of `input`. */ -def sqrt[D <: RealNN](input: Tensor[D]): Tensor[FloatPromoted[D]] = +def sqrt[D <: DType](input: Tensor[D]): Tensor[FloatPromoted[D]] = Tensor(torchNative.sqrt(input.native)) /** Returns a new tensor with the square of the elements of `input`. */ -def square[D <: RealNN](input: Tensor[D]): Tensor[NumericPromoted[D]] = +def square[D <: DType](input: Tensor[D]): Tensor[NumericPromoted[D]] = Tensor(torchNative.square(input.native)) /** Subtracts `other`, scaled by `alpha`, from `input`. */ -def sub[D <: DType, D2 <: DType](input: Tensor[D], other: Tensor[D2]): Tensor[Promoted[D, D2]] = +def sub[D <: NumericNN, D2 <: NumericNN]( + input: Tensor[D], + other: Tensor[D2] +): Tensor[Promoted[D, D2]] = Tensor(torchNative.sub(input.native, other.native)) -def sub[D <: DType, D2 <: DType]( +def sub[D <: NumericNN, D2 <: NumericNN]( input: Tensor[D], other: Tensor[D2], alpha: ScalaType ): Tensor[Promoted[D, D2]] = Tensor(torchNative.sub(input.native, other.native, toScalar(alpha))) -def sub[D <: DType, S <: ScalaType]( +def sub[D <: NumericNN, D2 <: NumericNN]( input: Tensor[D], - other: S, + other: Numeric, alpha: ScalaType -): Tensor[Promoted[D, ScalaToDType[S]]] = +): Tensor[Promoted[D, D2]] = Tensor(torchNative.sub(input.native, toScalar(other), toScalar(alpha))) /** Returns a new tensor with the tangent of the elements of `input`. */ -def tan[D <: RealNN](input: Tensor[D]): Tensor[FloatPromoted[D]] = +def tan[D <: DType](input: Tensor[D]): Tensor[FloatPromoted[D]] = Tensor(torchNative.tan(input.native)) /** Returns a new tensor with the hyperbolic tangent of the elements of `input`. 
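+ * A usage sketch (editor-added; values mirror the test suite in this series):
+ * {{{
+ * torch.tanh(Tensor(Seq(0.8986, -0.7279, 1.1745, 0.2611)))
+ * // ~ Tensor(Seq(0.7156, -0.6218, 0.8257, 0.2553))
+ * }}}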
*/ -def tanh[D <: RealNN](input: Tensor[D]): Tensor[FloatPromoted[D]] = +def tanh[D <: DType](input: Tensor[D]): Tensor[FloatPromoted[D]] = Tensor(torchNative.tanh(input.native)) -// TODO true_divide +/** Alias for `torch.div()` with `rounding_mode=None` */ +def trueDivide[D <: DType, D2 <: DType]( + input: Tensor[D], + other: Tensor[D2] +): Tensor[FloatPromoted[Promoted[D, D2]]] = + Tensor(torchNative.true_divide(input.native, other.native)) + +def trueDivide[D <: DType, S <: ScalaType]( + input: Tensor[D], + other: S +): Tensor[FloatPromoted[Promoted[D, ScalaToDType[S]]]] = + Tensor(torchNative.true_divide(input.native, toScalar(other))) /** Returns a new tensor with the truncated integer values of the elements of `input`. */ def trunc[D <: NumericRealNN](input: Tensor[D]): Tensor[D] = Tensor(torchNative.trunc(input.native)) -// TODO xlogy +export torch.special.xlogy // End Pointwise Ops diff --git a/core/src/test/scala/torch/Generators.scala b/core/src/test/scala/torch/Generators.scala index 953d75d8..170c240b 100644 --- a/core/src/test/scala/torch/Generators.scala +++ b/core/src/test/scala/torch/Generators.scala @@ -40,7 +40,7 @@ object Generators: int64, float32, float64, - complex32, + // complex32, // NOTE: A lot of CPU operations do not support this dtype yet complex64, complex128, bool, @@ -49,7 +49,7 @@ object Generators: // qint32, bfloat16 // quint4x2, - // float16, // NOTE: A lot of CPU do not support this dtype + // float16, // NOTE: A lot of CPU operations do not support this dtype yet // undefined, // numoptions ) diff --git a/core/src/test/scala/torch/TensorCheckSuite.scala b/core/src/test/scala/torch/TensorCheckSuite.scala new file mode 100644 index 00000000..de8af28c --- /dev/null +++ b/core/src/test/scala/torch/TensorCheckSuite.scala @@ -0,0 +1,193 @@ +/* + * Copyright 2022 storch.dev + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package torch + +import munit.ScalaCheckSuite +import shapeless3.typeable.{TypeCase, Typeable} +import shapeless3.typeable.syntax.typeable.* +import Generators.{*, given} +import org.scalacheck.Prop.* + +import scala.util.Try + +trait TensorCheckSuite extends ScalaCheckSuite { + + given tensorTypeable[T <: DType](using tt: Typeable[T]): Typeable[Tensor[T]] with + def castable(t: Any): Boolean = + t match + case (tensor: Tensor[?]) => + tensor.dtype.castable[T] + case _ => false + def describe = s"Tensor[${tt.describe}]" + + private def propertyTestName(opName: String) = s"${opName}.property-test" + private def unitTestName(opName: String) = s"${opName}.unit-test" + + inline def propertyTestBinaryOp[InA <: DType, InB <: DType]( + op: Function2[Tensor[InA], Tensor[InB], ?], + opName: String, + skipPropertyTestReason: Option[String] = None + ): Unit = + test(propertyTestName(opName)) { + assume(skipPropertyTestReason.isEmpty, skipPropertyTestReason) + + // TODO Validate output types + val tensorInACase = TypeCase[Tensor[InA]] + val tensorInBCase = TypeCase[Tensor[InB]] + forAll(genTensor, genTensor) { + case (tensorInACase(tensorA), tensorInBCase(tensorB)) => + val result = Try(op(tensorA, tensorB)) + assert( + result.isSuccess, + s"""| + |Tensor operation 'torch.${opName}' does not support (${tensorA.dtype}, ${tensorB.dtype}) inputs + | + |${result.failed.get} + """.stripMargin + ) + case (tensorA, tensorB) => + val result = Try(op(tensorA.asInstanceOf[Tensor[InA]], tensorB.asInstanceOf[Tensor[InB]])) + assert( + result.isFailure, + s"""| + |Tensor operation 'torch.${opName}' supports (A: ${tensorA.dtype}, B: ${tensorB.dtype}) inputs but storch interface is currently restricted to + |Type A: ${tensorInACase} + |Type B: ${tensorInBCase} + """.stripMargin + ) + } + } + + inline def propertyTestUnaryOp[In <: DType]( + op: Function1[Tensor[In], ?], + opName: String + ): Unit = + test(propertyTestName(opName)) { + // TODO Validate output types + val tensorInCase = TypeCase[Tensor[In]] + forAll(genTensor) { + case tensorInCase(tensor) => + val result = Try(op(tensor)) + assert( + result.isSuccess, + s"""| + |Tensor operation 'torch.${opName}' does not support ${tensor.dtype} inputs + | + |${result.failed.get} + """.stripMargin + ) + case tensor => + val result = Try(op(tensor.asInstanceOf[Tensor[In]])) + assert( + result.isFailure, + s"""| + |Tensor operation 'torch.${opName}' supports ${tensor.dtype} inputs but storch interface is currently restricted to ${tensorInCase} + """.stripMargin + ) + } + } + + inline def unitTestBinaryOp[ + InA <: DType, + InAS <: ScalaType, + InB <: DType, + InBS <: ScalaType + ]( + op: Function2[Tensor[InA], Tensor[InB], Tensor[?]], + opName: String, + inline inputTensors: (Tensor[ScalaToDType[InAS]], Tensor[ScalaToDType[InBS]]), + inline expectedTensor: Tensor[?], + absolutePrecision: Double = 1e-04 + )(using ScalaToDType[InAS] <:< InA, ScalaToDType[InBS] <:< InB): Unit = + test(unitTestName(opName)) { + val outputTensor = op( + inputTensors._1.asInstanceOf[Tensor[InA]], + inputTensors._2.asInstanceOf[Tensor[InB]] + ) + val allclose = outputTensor.allclose( + other = expectedTensor, + atol = absolutePrecision, + equalNan = true + ) + assert( + allclose, + s"""| + |Tensor results are not all close for 'torch.${opName}' + | + |Input tensors: + |${inputTensors} + | + |Output tensor: + |${outputTensor} + | + |Expected tensor: + |${expectedTensor}""".stripMargin + ) + } + + inline def unitTestUnaryOp[In <: DType, InS <: ScalaType]( + op: Function1[Tensor[In], 
Tensor[?]], + opName: String, + inline inputTensor: Tensor[ScalaToDType[InS]], + inline expectedTensor: Tensor[?], + absolutePrecision: Double = 1e-04 + )(using ScalaToDType[InS] <:< In): Unit = + test(unitTestName(opName)) { + val outputTensor = op(inputTensor.asInstanceOf[Tensor[In]]) + val allclose = outputTensor.allclose( + other = expectedTensor, + atol = absolutePrecision, + equalNan = true + ) + assert( + allclose, + s"""| + |Tensor results are not all close for 'torch.${opName}' + | + |Input tensor: + |${inputTensor} + | + |Output tensor: + |${outputTensor} + | + |Expected tensor: + |${expectedTensor}""".stripMargin + ) + } + + inline def testBinaryOp[InA <: DType, InAS <: ScalaType, InB <: DType, InBS <: ScalaType]( + op: Function2[Tensor[InA], Tensor[InB], Tensor[?]], + opName: String, + inline inputTensors: (Tensor[ScalaToDType[InAS]], Tensor[ScalaToDType[InBS]]), + inline expectedTensor: Tensor[?], + absolutePrecision: Double = 1e-04, + skipPropertyTestReason: Option[String] = None + )(using ScalaToDType[InAS] <:< InA, ScalaToDType[InBS] <:< InB): Unit = + propertyTestBinaryOp(op, opName, skipPropertyTestReason) + unitTestBinaryOp(op, opName, inputTensors, expectedTensor, absolutePrecision) + + inline def testUnaryOp[In <: DType, InS <: ScalaType]( + op: Function1[Tensor[In], Tensor[?]], + opName: String, + inline inputTensor: Tensor[ScalaToDType[InS]], + inline expectedTensor: Tensor[?], + absolutePrecision: Double = 1e-04 + )(using ScalaToDType[InS] <:< In): Unit = + propertyTestUnaryOp(op, opName) + unitTestUnaryOp(op, opName, inputTensor, expectedTensor, absolutePrecision) + +} diff --git a/core/src/test/scala/torch/TensorSuite.scala b/core/src/test/scala/torch/TensorSuite.scala index 4f8e7609..40e21c53 100644 --- a/core/src/test/scala/torch/TensorSuite.scala +++ b/core/src/test/scala/torch/TensorSuite.scala @@ -16,71 +16,11 @@ package torch -import DeviceType.CUDA - -import java.nio.{IntBuffer, LongBuffer} - -import munit.ScalaCheckSuite -import torch.DeviceType.CUDA import org.scalacheck.Prop.* -import org.bytedeco.pytorch.global.torch as torch_native -import org.scalacheck.{Arbitrary, Gen} -import org.scalacheck._ -import Gen._ -import Arbitrary.arbitrary -import DeviceType.CPU import Generators.{*, given} -import scala.util.Try import spire.math.Complex -import spire.implicits.DoubleAlgebra - -class TensorSuite extends ScalaCheckSuite { - - inline private def testUnaryOp[In <: DType, InS <: ScalaType]( - op: Tensor[In] => Tensor[?], - opName: String, - inline inputTensor: Tensor[ScalaToDType[InS]], - inline expectedTensor: Tensor[?], - absolutePrecision: Double = 1e-04 - )(using ScalaToDType[InS] <:< In): Unit = - val propertyTestName = s"${opName}.property-test" - test(propertyTestName) { - forAll(genTensor[In]) { (tensor) => - val result = Try(op(tensor)) - // TODO Validate output types - assert( - result.isSuccess, - s"""| - |Tensor operation 'torch.${opName}' does not support ${tensor.dtype} inputs - | - |${result.failed.get} - """.stripMargin - ) - } - } - val unitTestName = s"${opName}.unit-test" - test(unitTestName) { - val outputTensor = op(inputTensor.asInstanceOf[Tensor[In]]) - val allclose = outputTensor.allclose( - other = expectedTensor, - atol = absolutePrecision, - equalNan = true - ) - assert( - allclose, - s"""| - |Tensor results are not all close for 'torch.${opName}' - | - |Input tensor: - |${inputTensor} - | - |Output tensor: - |${outputTensor} - | - |Expected tensor: - |${expectedTensor}""".stripMargin - ) - } + +class TensorSuite extends 
TensorCheckSuite {
 
   test("arange") {
     val t0 = arange(0, 10)
@@ -156,9 +96,53 @@ class TensorSuite extends ScalaCheckSuite {
     assert(t.grad.equal(torch.ones(Seq(3))))
   }
 
+  test("indexing") {
+    val tensor = torch.arange(0, 16).reshape(4, 4)
+    // first row
+    assertEquals(tensor(0), Tensor(Seq(0, 1, 2, 3)))
+    // first column
+    assertEquals(tensor(torch.Slice(), 0), Tensor(Seq(0, 4, 8, 12)))
+    // last column
+    assertEquals(tensor(---, -1), Tensor(Seq(3, 7, 11, 15)))
+  }
+
+  testUnaryOp(
+    op = abs,
+    opName = "abs",
+    inputTensor = Tensor(Seq(-1, -2, 3)),
+    expectedTensor = Tensor(Seq(1, 2, 3))
+  )
+
+  testUnaryOp(
+    op = acos,
+    opName = "acos",
+    inputTensor = Tensor(Seq(0.3348, -0.5889, 0.2005, -0.1584)),
+    expectedTensor = Tensor(Seq(1.2294, 2.2004, 1.3690, 1.7298))
+  )
+
+  testUnaryOp(
+    op = acosh,
+    opName = "acosh",
+    inputTensor = Tensor(Seq(1.3192, 1.9915, 1.9674, 1.7151)),
+    expectedTensor = Tensor(Seq(0.7791, 1.3120, 1.2979, 1.1341))
+  )
+
+  testUnaryOp(
+    op = add(_, other = 20),
+    opName = "add",
+    inputTensor = Tensor(Seq(0.0202, 1.0985, 1.3506, -0.6056)),
+    expectedTensor = Tensor(Seq(20.0202, 21.0985, 21.3506, 19.3944))
+  )
+
   // TODO addcdiv
   // TODO addcmul
-  // TODO angle
+
+  testUnaryOp(
+    op = angle,
+    opName = "angle",
+    inputTensor = Tensor(Seq(Complex(-1.0, 1.0), Complex(-2.0, 2.0), Complex(3.0, -3.0))),
+    expectedTensor = Tensor(Seq(2.3562, 2.3562, -0.7854))
+  )
 
   testUnaryOp(
     op = asin,
@@ -185,10 +176,18 @@ class TensorSuite extends ScalaCheckSuite {
     expectedTensor = Tensor(Seq(-1.7253, 0.3060, -1.2899, -0.1893))
   )
 
-  // TODO atan2
+  testBinaryOp(
+    op = atan2,
+    opName = "atan2",
+    inputTensors = (
+      Tensor(Seq(0.9041, 0.0196, -0.3108, -2.4423)),
+      Tensor(Seq(1.3104, -1.5804, 0.6674, 0.7710))
+    ),
+    expectedTensor = Tensor(Seq(0.6039, 3.1292, -0.4358, -1.2650))
+  )
+
+  // TODO Test boolean cases for bitwise operations
 
-  // TODO Test boolean cases for bitwise_not
-  // https://pytorch.org/docs/stable/generated/torch.bitwise_not.html
   testUnaryOp(
     op = bitwiseNot,
     opName = "bitwiseNot",
@@ -196,11 +195,57 @@ class TensorSuite extends ScalaCheckSuite {
     expectedTensor = Tensor(Seq(0, 1, -4))
   )
 
-  // TODO bitwise_and
-  // TODO bitwise_or
-  // TODO bitwise_xor
-  // TODO bitwise_left_shift
-  // TODO bitwise_right_shift
+  testBinaryOp(
+    op = bitwiseAnd,
+    opName = "bitwiseAnd",
+    inputTensors = (
+      Tensor(Seq(-1, -2, 3)),
+      Tensor(Seq(1, 0, 3))
+    ),
+    expectedTensor = Tensor(Seq(1, 0, 3))
+  )
+
+  testBinaryOp(
+    op = bitwiseOr,
+    opName = "bitwiseOr",
+    inputTensors = (
+      Tensor(Seq(-1, -2, 3)),
+      Tensor(Seq(1, 0, 3))
+    ),
+    expectedTensor = Tensor(Seq(-1, -2, 3))
+  )
+
+  testBinaryOp(
+    op = bitwiseXor,
+    opName = "bitwiseXor",
+    inputTensors = (
+      Tensor(Seq(-1, -2, 3)),
+      Tensor(Seq(1, 0, 3))
+    ),
+    expectedTensor = Tensor(Seq(-2, -2, 0))
+  )
+
+  // TODO Enable property test once we figure out how to consider OnlyOneBool evidence in genDType
+  unitTestBinaryOp(
+    op = bitwiseLeftShift,
+    opName = "bitwiseLeftShift",
+    inputTensors = (
+      Tensor(Seq(-1, -2, 3)),
+      Tensor(Seq(1, 0, 3))
+    ),
+    expectedTensor = Tensor(Seq(-2, -2, 24))
+  )
+
+  // TODO Enable property test once we figure out how to consider OnlyOneBool evidence in genDType
+  unitTestBinaryOp(
+    op = bitwiseRightShift,
+    opName = "bitwiseRightShift",
+    inputTensors = (
+      Tensor(Seq(-2, -7, 31)),
+      Tensor(Seq(1, 0, 3))
+    ),
+    expectedTensor =
Tensor(Seq(-1, -7, 3))
+  )
 
   testUnaryOp(
     op = ceil,
@@ -209,7 +254,13 @@ class TensorSuite extends ScalaCheckSuite {
     expectedTensor = Tensor(Seq(-0.0, -1.0, -1.0, 1.0))
   )
 
-  // TODO clamp
+  // TODO Test min and max inputs
+  testUnaryOp(
+    op = clamp(_, min = Some(-0.5), max = Some(0.5)),
+    opName = "clamp",
+    inputTensor = Tensor(Seq(-1.7120, 0.1734, -0.0478, -0.0922)),
+    expectedTensor = Tensor(Seq(-0.5, 0.1734, -0.0478, -0.0922))
+  )
 
   testUnaryOp(
     op = conjPhysical,
@@ -218,7 +269,15 @@ class TensorSuite extends ScalaCheckSuite {
     expectedTensor = Tensor(Seq(Complex(-1.0, -1.0), Complex(-2.0, -2.0), Complex(3.0, 3.0)))
   )
 
-  // TODO copysign
+  testBinaryOp(
+    op = copysign,
+    opName = "copysign",
+    inputTensors = (
+      Tensor(Seq(0.7079, 0.2778, -1.0249, 0.5719)),
+      Tensor(Seq(0.2373, 0.3120, 0.3190, -1.1128))
+    ),
+    expectedTensor = Tensor(Seq(0.7079, 0.2778, 1.0249, -0.5719))
+  )
 
   testUnaryOp(
     op = cos,
@@ -241,7 +300,15 @@ class TensorSuite extends ScalaCheckSuite {
     expectedTensor = Tensor(Seq(3.1416, -3.1416, 6.2832, -6.2832, 1.5708, -1.5708))
  )
 
-  // TODO div
+  testBinaryOp(
+    op = div,
+    opName = "div",
+    inputTensors = (
+      Tensor(Seq(-0.3711, -1.9353, -0.4605, -0.2917)),
+      Tensor(Seq(0.8032, 0.2930, -0.8113, -0.2308))
+    ),
+    expectedTensor = Tensor(Seq(-0.4620, -6.6051, 0.5676, 1.2639))
+  )
 
   testUnaryOp(
     op = digamma,
@@ -294,7 +361,23 @@ class TensorSuite extends ScalaCheckSuite {
 
   // TODO fakeQuantizePerChannelAffine
   // TODO fakeQuantizePerTensorAffine
-  // TODO floatPower
+
+  testUnaryOp(
+    op = fix,
+    opName = "fix",
+    inputTensor = Tensor(Seq(3.4742, 0.5466, -0.8008, -0.9079)),
+    expectedTensor = Tensor(Seq(3.0, 0.0, -0.0, -0.0))
+  )
+
+  testBinaryOp(
+    op = floatPower,
+    opName = "floatPower",
+    inputTensors = (
+      Tensor(Seq(1, 2, 3, 4)),
+      Tensor(Seq(2, -3, 4, -5))
+    ),
+    expectedTensor = Tensor(Seq(1.0, 0.125, 81.0, 9.7656e-4))
+  )
 
   testUnaryOp(
     op = floor,
@@ -303,8 +386,27 @@ class TensorSuite extends ScalaCheckSuite {
     expectedTensor = Tensor(Seq(-1.0, 1.0, -1.0, -1.0))
   )
 
-  // TODO floorDivide
-  // TODO fmod
+  // TODO Enable property test once we figure out how to consider OnlyOneBool evidence in genDType
+  unitTestBinaryOp(
+    op = floorDivide,
+    opName = "floorDivide",
+    inputTensors = (
+      Tensor(Seq(4.0, 3.0)),
+      Tensor(Seq(2.0, 2.0))
+    ),
+    expectedTensor = Tensor(Seq(2.0, 1.0))
+  )
+
+  // TODO Enable property test once we figure out how to consider OnlyOneBool evidence in genDType
+  unitTestBinaryOp(
+    op = fmod,
+    opName = "fmod",
+    inputTensors = (
+      Tensor(Seq(-3.0, -2.0, -1.0, 1.0, 2.0, 3.0)),
+      Tensor(Seq(2.0, 2.0, 2.0, 2.0, 2.0, 2.0))
+    ),
+    expectedTensor = Tensor(Seq(-1.0, -0.0, -1.0, 1.0, 0.0, 1.0))
+  )
 
   testUnaryOp(
     op = frac,
@@ -338,8 +440,36 @@ class TensorSuite extends ScalaCheckSuite {
     expectedTensor = Tensor(Seq(0.3553, -0.7896, -0.0633, -0.8119))
   )
 
-  // TODO ldexp
-  // TODO lerp
+  testBinaryOp(
+    op = ldexp,
+    opName = "ldexp",
+    inputTensors = (
+      Tensor(Seq(1.0)),
+      Tensor(Seq(1, 2, 3, 4))
+    ),
+    expectedTensor = Tensor(Seq(2.0, 4.0, 8.0, 16.0))
+  )
+
+  // TODO Test weight as tensor
+  // TODO lerp must accept operands of the same type, so we wrap it for the generators to work properly
+  // testBinaryOp(
+  //   op = lerp(_, _, weight = 0.5),
+  //   opName = "lerp",
+  //   inputTensors = (
+  //     Tensor(Seq(1.0, 2.0, 3.0, 4.0)),
+  //     Tensor(Seq(10.0, 10.0, 10.0, 10.0))
+  //   ),
+  //   expectedTensor = Tensor(Seq(5.5, 6.0, 6.5, 7.0))
+  // )
+  unitTestBinaryOp(
+    op = lerp(_, _, weight = 0.5),
+    opName = "lerp",
+    inputTensors = (
+      Tensor(Seq(1.0, 2.0, 3.0, 4.0)),
+      Tensor(Seq(10.0, 10.0, 10.0, 10.0))
+    ),
+    expectedTensor = Tensor(Seq(5.5, 6.0, 6.5, 7.0))
+  )
 
   testUnaryOp(
     op = lgamma,
@@ -378,12 +508,40 @@ class TensorSuite extends ScalaCheckSuite {
     absolutePrecision = 1e-2
   )
 
-  // TODO logaddexp
-  // TODO logaddexp2
-  // TODO logicalAnd
+  // TODO Enable property test once we figure out how to consider OnlyOneBool evidence in genDType
+  unitTestBinaryOp(
+    op = logaddexp,
+    opName = "logaddexp",
+    inputTensors = (
+      Tensor(Seq(-100.0, -200.0, -300.0)),
+      Tensor(Seq(-1.0, -2.0, -3.0))
+    ),
+    expectedTensor = Tensor(Seq(-1.0, -2.0, -3.0))
+  )
+
+  // TODO Enable property test once we figure out how to consider OnlyOneBool evidence in genDType
+  unitTestBinaryOp(
+    op = logaddexp2,
+    opName = "logaddexp2",
+    inputTensors = (
+      Tensor(Seq(-100.0, -200.0, -300.0)),
+      Tensor(Seq(-1.0, -2.0, -3.0))
+    ),
+    expectedTensor = Tensor(Seq(-1.0, -2.0, -3.0))
+  )
+
+  // TODO Test int32 tensors
+  testBinaryOp(
+    op = logicalAnd,
+    opName = "logicalAnd",
+    inputTensors = (
+      Tensor(Seq(true, false, true)),
+      Tensor(Seq(true, false, false))
+    ),
+    expectedTensor = Tensor(Seq(true, false, false))
+  )
 
-  // TODO Handle numeric cases for logical_not
-  // https://pytorch.org/docs/stable/generated/torch.logical_not.html
+  // TODO Test int32 tensors
   testUnaryOp(
     op = logicalNot,
     opName = "logicalNot",
     inputTensor = Tensor(Seq(true, false)),
     expectedTensor = Tensor(Seq(false, true))
   )
 
-  // TODO logicalOr
-  // TODO logicalXor
-  // TODO logit
-  // TODO hypot
+  // TODO Test int32 tensors
+  testBinaryOp(
+    op = logicalOr,
+    opName = "logicalOr",
+    inputTensors = (
+      Tensor(Seq(true, false, true)),
+      Tensor(Seq(true, false, false))
+    ),
+    expectedTensor = Tensor(Seq(true, false, true))
+  )
+
+  // TODO Test int32 tensors
+  testBinaryOp(
+    op = logicalXor,
+    opName = "logicalXor",
+    inputTensors = (
+      Tensor(Seq(true, false, true)),
+      Tensor(Seq(true, false, false))
+    ),
+    expectedTensor = Tensor(Seq(false, false, true))
+  )
+
+  testUnaryOp(
+    op = logit(_, Some(1e-6)),
+    opName = "logit",
+    inputTensor = Tensor(Seq(0.2796, 0.9331, 0.6486, 0.1523, 0.6516)),
+    expectedTensor = Tensor(Seq(-0.9466, 2.6352, 0.6131, -1.7169, 0.6261)),
+    absolutePrecision = 1e-3
+  )
+
+  // TODO Enable property test once we figure out how to compile properly with AtLeastOneFloat
+  unitTestBinaryOp(
+    op = hypot,
+    opName = "hypot",
+    inputTensors = (Tensor(Seq(4.0)), Tensor(Seq(3.0, 4.0, 5.0))),
+    expectedTensor = Tensor(Seq(5.0, 5.6569, 6.4031))
+  )
 
   testUnaryOp(
     op = i0,
@@ -403,11 +594,53 @@ class TensorSuite extends ScalaCheckSuite {
     expectedTensor = Tensor(Seq(1.0, 1.2661, 2.2796, 4.8808, 11.3019))
   )
 
-  // TODO igamma
-  // TODO igammac
-  // TODO mul
-  // TODO mvlgamma
-  // TODO nanToNum
+  // TODO Enable property test once we figure out how to compile properly with AtLeastOneFloat
+  unitTestBinaryOp(
+    op = igamma,
+    opName = "igamma",
+    inputTensors = (
+      Tensor(Seq(4.0)),
+      Tensor(Seq(3.0, 4.0, 5.0))
+    ),
+    expectedTensor = Tensor(Seq(0.3528, 0.5665, 0.7350))
+  )
+
+  // TODO Enable property test once we figure out how to compile properly with AtLeastOneFloat
+  unitTestBinaryOp(
+    op = igammac,
+    opName = "igammac",
+    inputTensors = (
+      Tensor(Seq(4.0)),
+      Tensor(Seq(3.0, 4.0, 5.0))
+    ),
+    expectedTensor = Tensor(Seq(0.6472, 0.4335, 0.2650))
+  )
+
+  testBinaryOp(
+    op = mul,
+    opName = "mul",
+    inputTensors = (
+      Tensor(Seq(1.1207)),
+      Tensor(Seq(0.5146, 0.1216, -0.5244, 2.2382))
+    ),
+    expectedTensor = Tensor(Seq(0.5767, 0.1363, -0.5877, 2.5083))
+  )
+
+  testUnaryOp(
+    op = mvlgamma(_, p = 2),
+    opName = "mvlgamma",
+    inputTensor =
+
+  testBinaryOp(
+    op = mul,
+    opName = "mul",
+    inputTensors = (
+      Tensor(Seq(1.1207)),
+      Tensor(Seq(0.5146, 0.1216, -0.5244, 2.2382))
+    ),
+    expectedTensor = Tensor(Seq(0.5767, 0.1363, -0.5877, 2.5083))
+  )
+
+  testUnaryOp(
+    op = mvlgamma(_, p = 2),
+    opName = "mvlgamma",
+    inputTensor = Tensor(Seq(1.6835, 1.8474, 1.1929)),
+    expectedTensor = Tensor(Seq(0.3928, 0.4007, 0.7586))
+  )
+
+  // TODO Test nan, posinf, neginf arguments
+  // TODO Test float32
+  testUnaryOp(
+    op = nanToNum(_, nan = None, posinf = None, neginf = None),
+    opName = "nanToNum",
+    inputTensor = Tensor(Seq(Double.NaN, Double.PositiveInfinity, Double.NegativeInfinity, 3.14)),
+    expectedTensor = Tensor(Seq(0.0, 1.7976931348623157e308, -1.7976931348623157e308, 3.14))
+  )
 
   testUnaryOp(
     op = neg,
@@ -416,8 +649,26 @@ class TensorSuite extends ScalaCheckSuite {
     expectedTensor = Tensor(Seq(-0.0090, 0.2262, 0.0682, 0.2866, -0.3940))
   )
 
-  // TODO nextafter
-  // TODO polygamma
+  // TODO Enable property test once we figure out how to compile properly with the AtLeastOneFloat evidence
+  // TODO Fix this unit test, as it is not really meaningful due to fp precision
+  unitTestBinaryOp(
+    op = nextafter,
+    opName = "nextafter",
+    inputTensors = (
+      Tensor(Seq(1.0, 2.0)),
+      Tensor(Seq(2.0, 1.0))
+    ),
+    expectedTensor = Tensor(Seq(1.0, 2.0)),
+    absolutePrecision = 1e-8
+  )
+
+  // TODO Test multiple values of `n`
+  testUnaryOp(
+    op = polygamma(1, _),
+    opName = "polygamma",
+    inputTensor = Tensor(Seq(1.0, 0.5)),
+    expectedTensor = Tensor(Seq(1.64493, 4.9348))
+  )
 
   testUnaryOp(
     op = positive,
@@ -426,7 +677,18 @@ class TensorSuite extends ScalaCheckSuite {
     expectedTensor = Tensor(Seq(0.0090, -0.2262, -0.0682, -0.2866, 0.3940))
   )
 
-  // TODO pow
+  // TODO Test scalar exponent
+  // TODO Enable property test once we figure out how to consider the OnlyOneBool evidence in genDType
+  unitTestBinaryOp(
+    op = pow,
+    opName = "pow",
+    inputTensors = (
+      Tensor(Seq(1.0, 2.0, 3.0, 4.0)),
+      Tensor(Seq(1.0, 2.0, 3.0, 4.0))
+    ),
+    expectedTensor = Tensor(Seq(1.0, 4.0, 27.0, 256.0))
+  )
+
   // TODO quantized_batch_norm
   // TODO quantized_max_pool1d
   // TODO quantized_max_pool2d
@@ -459,8 +721,34 @@ class TensorSuite extends ScalaCheckSuite {
     expectedTensor = Tensor(Seq(-2.1763, -0.4713, -0.6986, 1.3702))
   )
 
-  // TODO remainder
-  // TODO round
+  // TODO Enable property test once we figure out how to consider the OnlyOneBool evidence in genDType
+  // propertyTestBinaryOp(remainder, "remainder")
+  test("remainder.unit-test") {
+    val result = remainder(Tensor(Seq(-3.0, -2.0, -1.0, 1.0, 2.0, 3.0)), 2)
+    val expected = Tensor(Seq(1.0, 0.0, 1.0, 1.0, 0.0, 1.0))
+    assert(allclose(result, expected))
+
+    val result2 = remainder(-1.5, Tensor(Seq(1, 2, 3, 4, 5))).to(dtype = float64)
+    val expected2 = Tensor(Seq(0.5, 0.5, 1.5, 2.5, 3.5))
+    assert(allclose(result2, expected2))
+
+    val result3 = remainder(Tensor(Seq(1, 2, 3, 4, 5)), Tensor(Seq(1, 2, 3, 4, 5)))
+    val expected3 = Tensor(Seq(0, 0, 0, 0, 0))
+    assert(allclose(result3, expected3))
+  }
+
+  testUnaryOp(
+    op = round(_, decimals = 0),
+    opName = "round",
+    inputTensor = Tensor(Seq(4.7, -2.3, 9.1, -7.7)),
+    expectedTensor = Tensor(Seq(5.0, -2.0, 9.0, -8.0))
+  )
+
+  test("round.unit-test.decimals") {
+    val input = Tensor(Seq(0.1234567))
+    val result = round(input, decimals = 3)
+    assert(allclose(result, Tensor(Seq(0.123)), atol = 1e-3))
+  }
 
   testUnaryOp(
     op = rsqrt,
@@ -470,7 +758,12 @@ class TensorSuite extends ScalaCheckSuite {
     absolutePrecision = 1e-3
   )
 
-  // TODO sigmoid
+  testUnaryOp(
+    op = sigmoid,
+    opName = "sigmoid",
+    inputTensor = Tensor(Seq(0.9213, 1.0887, -0.8858, -1.7683)),
+    expectedTensor = Tensor(Seq(0.7153, 0.7481, 0.2920, 0.1458))
+  )
 
   testUnaryOp(
     op = sign,
@@ -531,12 +824,18 @@ class TensorSuite extends ScalaCheckSuite {
     expectedTensor = Tensor(Seq(4.3077, 1.0457, 0.0069, 0.2310))
   )
 
-  test("sub") {
+  testBinaryOp(
+    op = sub,
+    opName = "sub",
+    inputTensors = (
+      Tensor(Seq(1, 2)),
+      Tensor(Seq(0, 1))
+    ),
+    expectedTensor = Tensor(Seq(1, 1))
+  )
+
+  test("sub.unit-test.alpha") {
     val a = Tensor(Seq(1, 2))
     val b = Tensor(Seq(0, 1))
-    val res = sub(a, b)
-    assertEquals(res, Tensor(Seq(1, 1)))
-
     val resAlpha = sub(a, b, alpha = 2)
     assertEquals(
       resAlpha,
@@ -559,6 +858,16 @@ class TensorSuite extends ScalaCheckSuite {
     expectedTensor = Tensor(Seq(0.7156, -0.6218, 0.8257, 0.2553))
   )
 
+  testBinaryOp(
+    op = trueDivide,
+    opName = "trueDivide",
+    inputTensors = (
+      Tensor(Seq(-0.3711, -1.9353, -0.4605, -0.2917)),
+      Tensor(Seq(0.8032, 0.2930, -0.8113, -0.2308))
+    ),
+    expectedTensor = Tensor(Seq(-0.4620, -6.6051, 0.5676, 1.2639))
+  )
+
   testUnaryOp(
     op = trunc,
     opName = "trunc",
@@ -566,13 +875,13 @@ class TensorSuite extends ScalaCheckSuite {
     expectedTensor = Tensor(Seq(3.0, 0.0, -0.0, -0.0))
   )
 
-  test("indexing") {
-    val tensor = torch.arange(0, 16).reshape(4, 4)
-    // first row
-    assertEquals(tensor(0), Tensor(Seq(0, 1, 2, 3)))
-    // first column
-    assertEquals(tensor(torch.Slice(), 0), Tensor(Seq(0, 4, 8, 12)))
-    // last column
-    assertEquals(tensor(---, -1), Tensor(Seq(3, 7, 11, 15)))
-  }
+  testBinaryOp(
+    op = xlogy,
+    opName = "xlogy",
+    inputTensors = (
+      Tensor(Seq(0, 0, 0, 0, 0)),
+      Tensor(Seq(-1.0, 0.0, 1.0, Double.PositiveInfinity, Double.NaN))
+    ),
+    expectedTensor = Tensor(Seq(0.0, 0.0, 0.0, 0.0, Double.NaN))
+  )
 }
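For readers following along: the `testUnaryOp` / `testBinaryOp` / `unitTestBinaryOp` helpers used throughout are defined elsewhere in this suite, so only their call shape is visible in this diff. A rough sketch of what the unit-test variant might look like, assuming munit's `test` and this library's `allclose`; the signature, the default tolerance, and the test-name scheme are assumptions, not the actual helper code:

```scala
// Hypothetical sketch only. The real testBinaryOp additionally drives
// ScalaCheck property tests over generated dtypes, which is what the
// "Enable property test" TODOs above refer to.
def unitTestBinaryOp[D <: DType, D2 <: DType](
    op: (Tensor[D], Tensor[D2]) => Tensor[?],
    opName: String,
    inputTensors: (Tensor[D], Tensor[D2]),
    expectedTensor: Tensor[?],
    absolutePrecision: Double = 1e-4
): Unit =
  test(s"$opName.unit-test") {
    val (a, b) = inputTensors
    // Compare with a tolerance, since most ops under test are floating point.
    assert(allclose(op(a, b), expectedTensor, atol = absolutePrecision))
  }
```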