-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #23 from davoclavo/tensor-ops
Tensor Pointwise ops
- Loading branch information
Showing
17 changed files
with
2,171 additions
and
50 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
watch_file devenv.nix | ||
watch_file devenv.yaml | ||
watch_file devenv.lock | ||
eval "$(devenv print-dev-env)" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,3 +9,8 @@ metals.sbt | |
*.worksheet.sc | ||
/data/ | ||
.scala-build/ | ||
|
||
# Devenv | ||
.devenv* | ||
devenv.local.nix | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
# Contributing to Storch | ||
|
||
|
||
## Install dependencies via nix + devenv | ||
|
||
1. Install [nix](https://nixos.org) package manager | ||
|
||
```bash | ||
sh <(curl -L https://nixos.org/nix/install) --daemon | ||
``` | ||
|
||
For more info, see https://nixos.org/download.html | ||
|
||
2. Install [devenv](https://devenv.sh) | ||
|
||
```bash | ||
nix profile install --accept-flake-config github:cachix/devenv/latest | ||
``` | ||
|
||
For more info, see: https://devenv.sh/getting-started/#installation | ||
|
||
3. [Optionally] Install [direnv](https://direnv.net/) | ||
|
||
This will load the specific environment variables upon `cd` into the storch folder | ||
|
||
```bash | ||
nix profile install 'nixpkgs#direnv' | ||
``` | ||
|
||
4. Load environment | ||
|
||
If you did not install direnv, run the following in the `storch` root folder: | ||
|
||
```bash | ||
devenv shell | ||
``` | ||
|
||
If you installed direnv, just `cd` into storch | ||
|
||
|
||
## Linting | ||
|
||
### Manually run headerCrate + scalafmt on all files | ||
|
||
```bash | ||
sbt 'headerCreateAll ; scalafmtAll' | ||
``` | ||
|
||
### Add useful git pre-push linting checks | ||
|
||
```bash | ||
cp git-hooks/pre-push-checks .git/hooks/ && chmod +x git-hooks/pre-push-checks | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
/* | ||
* Copyright 2022 storch.dev | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package torch | ||
|
||
import shapeless3.typeable.{TypeCase, Typeable} | ||
import shapeless3.typeable.syntax.typeable.* | ||
import scala.util.NotGiven | ||
import spire.math.Complex | ||
|
||
/* Typeable instance for Array[T] | ||
* NOTE: It needs to iterate through the whole array to validate casteability | ||
*/ | ||
given iterableTypeable[T](using tt: Typeable[T]): Typeable[Array[T]] with | ||
def castable(t: Any): Boolean = | ||
t match | ||
case (arr: Array[?]) => | ||
arr.forall(_.castable[T]) | ||
case _ => false | ||
def describe = s"Array[${tt.describe}]" | ||
|
||
/* TypeCase helpers to perform pattern matching on `Complex` higher kinded types */ | ||
val complexDoubleArray = TypeCase[Array[Complex[Double]]] | ||
val complexFloatArray = TypeCase[Array[Complex[Float]]] | ||
|
||
/* Type helper to describe inputs that accept Tensor or Real scalars */ | ||
type TensorOrReal[D <: RealNN] = Tensor[D] | Real | ||
|
||
/* Evidence used in operations where Bool is accepted, but only on one of the two inputs, not both | ||
*/ | ||
type OnlyOneBool[A <: DType, B <: DType] = NotGiven[A =:= Bool & B =:= Bool] | ||
|
||
/* Evidence used in operations where at least one Float is required */ | ||
type AtLeastOneFloat[A <: DType, B <: DType] = A <:< FloatNN | B <:< FloatNN |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.