Skip to content

Commit

Permalink
Remove workarounds for cusolver loading issue
Browse files Browse the repository at this point in the history
  • Loading branch information
sbrunk authored and Sören Brunk committed Jul 26, 2023
1 parent 624bd0f commit e2c2825
Show file tree
Hide file tree
Showing 14 changed files with 15 additions and 75 deletions.
7 changes: 3 additions & 4 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ val openblasVersion = "0.3.23"
val mklVersion = "2023.1"
ThisBuild / scalaVersion := "3.3.0"
ThisBuild / javaCppVersion := "1.5.10-SNAPSHOT"
ThisBuild / resolvers ++= Resolver.sonatypeOssRepos("snapshots")

ThisBuild / githubWorkflowJavaVersions := Seq(JavaSpec.temurin("11"))

Expand All @@ -40,8 +41,7 @@ ThisBuild / enableGPU := false
lazy val commonSettings = Seq(
Compile / doc / scalacOptions ++= Seq("-groups", "-snippet-compiler:compile"),
javaCppVersion := (ThisBuild / javaCppVersion).value,
javaCppPlatform := Seq(),
resolvers ++= Resolver.sonatypeOssRepos("snapshots")
javaCppPlatform := Seq()
// This is a hack to avoid depending on the native libs when publishing
// but conveniently have them on the classpath during development.
// There's probably a cleaner way to do this.
Expand Down Expand Up @@ -75,8 +75,7 @@ lazy val core = project
(if (enableGPU.value) "pytorch-gpu" else "pytorch") -> pytorchVersion,
"mkl" -> mklVersion,
"openblas" -> openblasVersion
// TODO remove cuda (not cuda-redist) once https://github.com/bytedeco/javacpp-presets/issues/1376 is fixed
) ++ (if (enableGPU.value) Seq("cuda-redist" -> cudaVersion, "cuda" -> cudaVersion) else Seq()),
) ++ (if (enableGPU.value) Seq("cuda-redist" -> cudaVersion) else Seq()),
javaCppPlatform := org.bytedeco.sbt.javacpp.Platform.current,
fork := true,
Test / fork := true,
Expand Down
2 changes: 0 additions & 2 deletions core/src/main/scala/torch/Tensor.scala
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@ import spire.math.{Complex, UByte}
import scala.reflect.Typeable
import internal.NativeConverters
import internal.NativeConverters.toArray
import internal.LoadCusolver
import Device.CPU
import Layout.Strided
import org.bytedeco.pytorch.ByteArrayRef
Expand Down Expand Up @@ -802,7 +801,6 @@ type IntTensor = UInt8Tensor | Int8Tensor | Int16Tensor | Int32Tensor | Int64Ten
type ComplexTensor = Complex32Tensor | Complex64Tensor | Complex128Tensor

object Tensor:
LoadCusolver // TODO workaround for https://github.com/bytedeco/javacpp-presets/issues/1376

def apply[D <: DType](native: pytorch.Tensor): Tensor[D] = (native.scalar_type().intern() match
case ScalarType.Byte => new UInt8Tensor(native)
Expand Down
2 changes: 0 additions & 2 deletions core/src/main/scala/torch/cuda/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,11 @@
package torch

import org.bytedeco.pytorch.global.torch as torchNative
import torch.internal.LoadCusolver

/** This package adds support for CUDA tensor types, that implement the same function as CPU
* tensors, but they utilize GPUs for computation.
*/
package object cuda {
LoadCusolver

/** Returns a Boolean indicating if CUDA is currently available. */
def isAvailable: Boolean = torchNative.cuda_is_available()
Expand Down
31 changes: 0 additions & 31 deletions core/src/main/scala/torch/internal/LoadCusolver.scala

This file was deleted.

3 changes: 0 additions & 3 deletions core/src/main/scala/torch/internal/NativeConverters.scala
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,9 @@ import org.bytedeco.pytorch.GenericDictIterator
import spire.math.Complex
import spire.math.UByte
import scala.annotation.targetName
import internal.LoadCusolver

private[torch] object NativeConverters:

LoadCusolver // TODO workaround for https://github.com/bytedeco/javacpp-presets/issues/1376

inline def convertToOptional[T, U <: T | Option[T], V >: Null](i: U, f: T => V): V = i match
case i: Option[T] => i.map(f(_)).orNull
case i: T => f(i)
Expand Down
5 changes: 1 addition & 4 deletions core/src/main/scala/torch/nn/functional/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ package torch
package nn

import functional.*
import torch.internal.LoadCusolver

/** @groupname nn_conv Convolution functions
* @groupname nn_pooling Pooling functions
Expand All @@ -38,6 +37,4 @@ package object functional
with Linear
with Loss
with Pooling
with Sparse {
LoadCusolver
}
with Sparse
2 changes: 0 additions & 2 deletions core/src/main/scala/torch/nn/init.scala
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,9 @@ import org.bytedeco.pytorch.kSigmoid
import org.bytedeco.pytorch.kReLU
import org.bytedeco.pytorch.kLeakyReLU
import org.bytedeco.pytorch.Scalar
import torch.internal.LoadCusolver

// TODO implement remaining init functions
object init:
LoadCusolver
def kaimingNormal_(
t: Tensor[?],
a: Double = 0,
Expand Down
3 changes: 0 additions & 3 deletions core/src/main/scala/torch/nn/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,13 @@

package torch

import torch.internal.LoadCusolver

/** These are the basic building blocks for graphs.
*
* @groupname nn_conv Convolution Layers
* @groupname nn_linear Linear Layers
* @groupname nn_utilities Utilities
*/
package object nn {
LoadCusolver

export modules.Module
export modules.Default
Expand Down
2 changes: 0 additions & 2 deletions core/src/main/scala/torch/nn/utils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,8 @@ package nn

import org.bytedeco.pytorch.global.torch as torchNative
import org.bytedeco.pytorch.TensorVector
import torch.internal.LoadCusolver

object utils:
LoadCusolver
def clipGradNorm_(
parameters: Seq[Tensor[?]],
max_norm: Double,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,18 @@ package ops
import internal.NativeConverters.*

import org.bytedeco.pytorch.global.torch as torchNative
import org.bytedeco.pytorch.{TensorArrayRef, TensorVector}
import org.bytedeco.pytorch.TensorArrayRef
import org.bytedeco.pytorch.TensorVector

/** Indexing, Slicing, Joining, Mutating Ops
*
* https://pytorch.org/docs/stable/torch.html#indexing-slicing-joining-mutating-ops
*/
private[torch] trait IndexingSlicingJoiningOps {

private def toArrayRef(tensors: Seq[Tensor[?]]): TensorArrayRef =
new TensorArrayRef(new TensorVector(tensors.map(_.native)*))

/** Returns a view of the tensor conjugated and with the last two dimensions transposed.
*
* `x.adjoint()` is equivalent to `x.transpose(-2, -1).conj()` for complex tensors and to
Expand Down
4 changes: 0 additions & 4 deletions core/src/main/scala/torch/ops/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,6 @@ import org.bytedeco.pytorch.{MemoryFormatOptional, TensorArrayRef, TensorVector}

package object ops {

private[torch] def toArrayRef(tensors: Seq[Tensor[?]]): TensorArrayRef =
val vector = new TensorVector(tensors.map(_.native) *)
new TensorArrayRef(vector.front(), vector.size())

private[torch] def xLike[D <: DType, D2 <: DType | Derive](
input: Tensor[D],
dtype: D2,
Expand Down
3 changes: 0 additions & 3 deletions core/src/main/scala/torch/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@
* limitations under the License.
*/

import torch.internal.LoadCusolver

import scala.util.Using

/** The torch package contains data structures for multi-dimensional tensors and defines
Expand All @@ -37,7 +35,6 @@ package object torch
with ops.PointwiseOps
with ops.RandomSamplingOps
with ops.ReductionOps {
LoadCusolver // TODO workaround for https://github.com/bytedeco/javacpp-presets/issues/1376

/** Disable gradient calculation for [[op]].
*
Expand Down
18 changes: 6 additions & 12 deletions docs/installation.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ libraryDependencies += Seq(
```scala
//> using scala "3.3"
//> using repository "sonatype-s01:snapshots"
//> using repository "sonatype:snapshots"
//> using lib "dev.storch::core:@VERSION@"
```

Expand Down Expand Up @@ -66,6 +67,7 @@ fork := true
```scala
//> using scala "3.3"
//> using repository "sonatype-s01:snapshots"
//> using repository "sonatype:snapshots"
//> using lib "dev.storch::core:@VERSION@"
//> using lib "org.bytedeco:pytorch-platform:@PYTORCH_VERSION@-@JAVACPP_VERSION@"
```
Expand Down Expand Up @@ -101,6 +103,7 @@ fork := true
```scala
//> using scala "3.3"
//> using repository "sonatype-s01:snapshots"
//> using repository "sonatype:snapshots"
//> using lib "dev.storch::core:@VERSION@"
//> using lib "org.bytedeco:openblas:@OPENBLAS_VERSION@-@JAVACPP_VERSION@,classifier=linux-x86_64"
//> using lib "org.bytedeco:pytorch:@PYTORCH_VERSION@-@JAVACPP_VERSION@,classifier=linux-x86_64"
Expand Down Expand Up @@ -159,6 +162,7 @@ fork := true
```scala
//> using scala "3.3"
//> using repository "sonatype-s01:snapshots"
//> using repository "sonatype:snapshots"
//> using lib "dev.storch::core:@VERSION@"
//> using lib "org.bytedeco:pytorch-platform-gpu:@PYTORCH_VERSION@-@JAVACPP_VERSION@"
//> using lib "org.bytedeco:cuda-platform-redist:@CUDA_VERSION@-@JAVACPP_VERSION@"
Expand Down Expand Up @@ -193,18 +197,13 @@ fork := true
```

@:choice(scala-cli)

**Warning**: This is currently **not working** due to with scala-cli not resolving mixed dependencies with and without
classifiers or with multiple classifiers. Please use the platform variant above until it is solved.

```scala
//> using scala "3.3"
//> using repository "sonatype-s01:snapshots"
//> using repository "sonatype:snapshots"
//> using lib "dev.storch::core:@VERSION@"
//> using lib "org.bytedeco:pytorch:@PYTORCH_VERSION@-@JAVACPP_VERSION@,classifier=linux-x86_64-gpu"
//> using lib "org.bytedeco:openblas:@OPENBLAS_VERSION@-@JAVACPP_VERSION@,classifier=linux-x86_64"
//> using lib "org.bytedeco:cuda:@CUDA_VERSION@-@JAVACPP_VERSION@"
//> using lib "org.bytedeco:cuda:@CUDA_VERSION@-@JAVACPP_VERSION@,classifier=linux-x86_64"
//> using lib "org.bytedeco:cuda:@CUDA_VERSION@-@JAVACPP_VERSION@,classifier=linux-x86_64-redist"
```

Expand All @@ -226,12 +225,7 @@ resolvers ++= Resolver.sonatypeOssRepos("snapshots")
libraryDependencies += Seq(
"dev.storch" %% "core" % "@VERSION@",
)
javaCppPresetLibs ++= Seq(
"pytorch-gpu" -> "@PYTORCH_VERSION@",
"openblas" -> "@OPENBLAS_VERSION@",
"cuda-redist" -> "@CUDA_VERSION@",
"cuda" -> "@CUDA_VERSION@"
)
javaCppPresetLibs ++= Seq("pytorch-gpu" -> "@PYTORCH_VERSION@", "openblas" -> "@OPENBLAS_VERSION@", "cuda-redist" -> "@CUDA_VERSION@")
fork := true
```

Expand Down
2 changes: 0 additions & 2 deletions vision/src/main/scala/torchvision/datasets/MNIST.scala
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ import java.util.zip.GZIPInputStream
import scala.util.Try
import scala.util.Success
import scala.util.Failure
import torch.internal.LoadCusolver

trait MNISTBase(
val mirrors: Seq[String],
Expand All @@ -37,7 +36,6 @@ trait MNISTBase(
val train: Boolean,
val download: Boolean
) extends TensorDataset[Float32, Int64] {
LoadCusolver

private def downloadAndExtractArchive(url: URL, target: Path): Unit =
println(s"downloading from $url")
Expand Down

0 comments on commit e2c2825

Please sign in to comment.