Skip to content

Commit

Permalink
Merge pull request #31 from sbrunk/cuda12
Browse files Browse the repository at this point in the history
Upgrade to cuda 12 and add workaround for missing libcusolver
  • Loading branch information
sbrunk authored Jun 25, 2023
2 parents 89c4d5f + a573b39 commit bfcaab4
Show file tree
Hide file tree
Showing 6 changed files with 75 additions and 7 deletions.
7 changes: 5 additions & 2 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ ThisBuild / apiURL := Some(new URL("https://storch.dev/api/"))

val scrImageVersion = "4.0.34"
val pytorchVersion = "2.0.1"
val cudaVersion = "12.1-8.9"
val openblasVersion = "0.3.23"
val mklVersion = "2023.1"
ThisBuild / scalaVersion := "3.3.0"
Expand Down Expand Up @@ -74,7 +75,8 @@ lazy val core = project
(if (enableGPU.value) "pytorch-gpu" else "pytorch") -> pytorchVersion,
"mkl" -> mklVersion,
"openblas" -> openblasVersion
) ++ (if (enableGPU.value) Seq("cuda-redist" -> "11.8-8.6") else Seq()),
// TODO remove cuda (not cuda-redist) once https://github.com/bytedeco/javacpp-presets/issues/1376 is fixed
) ++ (if (enableGPU.value) Seq("cuda-redist" -> cudaVersion, "cuda" -> cudaVersion) else Seq()),
javaCppPlatform := org.bytedeco.sbt.javacpp.Platform.current,
fork := true,
Test / fork := true,
Expand Down Expand Up @@ -125,7 +127,8 @@ lazy val docs = project
"JAVACPP_VERSION" -> javaCppVersion.value,
"PYTORCH_VERSION" -> pytorchVersion,
"OPENBLAS_VERSION" -> openblasVersion,
"MKL_VERSION" -> mklVersion
"MKL_VERSION" -> mklVersion,
"CUDA_VERSION" -> cudaVersion
),
ScalaUnidoc / unidoc / unidocProjectFilter := inAnyProject -- inProjects(examples),
Laika / sourceDirectories ++= Seq(sourceDirectory.value),
Expand Down
3 changes: 3 additions & 0 deletions core/src/main/scala/torch/Tensor.scala
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ import org.bytedeco.pytorch.DoubleArrayRef
import org.bytedeco.pytorch.EllipsisIndexType
import org.bytedeco.pytorch.SymInt
import org.bytedeco.pytorch.SymIntOptional
import internal.LoadCusolver

case class TensorTuple[D <: DType](
values: Tensor[D],
Expand Down Expand Up @@ -729,6 +730,8 @@ type IntTensor = UInt8Tensor | Int8Tensor | Int16Tensor | Int32Tensor | Int64Ten
type ComplexTensor = Complex32Tensor | Complex64Tensor | Complex128Tensor

object Tensor:
LoadCusolver // TODO workaround for https://github.com/bytedeco/javacpp-presets/issues/1376

def apply[D <: DType](native: pytorch.Tensor): Tensor[D] = (native.scalar_type().intern() match
case ScalarType.Byte => new UInt8Tensor(native)
case ScalarType.Char => new Int8Tensor(native)
Expand Down
32 changes: 32 additions & 0 deletions core/src/main/scala/torch/internal/LoadCusolver.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
/*
* Copyright 2022 storch.dev
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package torch
package internal

import org.bytedeco.javacpp.Loader

// This is a workaround for https://github.com/bytedeco/javacpp-presets/issues/1376
// TODO remove once the issue is fixed
/** Forces libcusolver to be loaded eagerly when CUDA is on the classpath.
  *
  * Workaround for https://github.com/bytedeco/javacpp-presets/issues/1376
  * TODO remove once the upstream issue is fixed.
  */
object LoadCusolver {
  try {
    // Look the class up reflectively so CPU-only builds (no cuda artifact) still compile and run.
    val cusolverClass = Class.forName("org.bytedeco.cuda.global.cusolver")
    Loader.load(cusolverClass)
  } catch {
    // The cuda preset is absent on CPU-only classpaths; skipping the load is the intended behavior.
    case _: ClassNotFoundException =>
  }
}
2 changes: 2 additions & 0 deletions core/src/main/scala/torch/internal/NativeConverters.scala
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ import scala.annotation.targetName

private[torch] object NativeConverters:

LoadCusolver // TODO workaround for https://github.com/bytedeco/javacpp-presets/issues/1376

inline def convertToOptional[T, U <: T | Option[T], V >: Null](i: U, f: T => V): V = i match
case i: Option[T] => i.map(f(_)).orNull
case i: T => f(i)
Expand Down
28 changes: 28 additions & 0 deletions docs/faq.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Frequently Asked Questions

## Q: I want to run operations on the GPU, but Storch seems to hang?

Depending on your hardware, the CUDA version and capability settings, CUDA might need to do
[just-in-time compilation](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#just-in-time-compilation) of your kernels, which
can take a few minutes. The result is cached, so it should load faster on subsequent runs.

If you're unsure, you can watch the size of the cache:

```bash
watch -d du -sm ~/.nv/ComputeCache
```
If it's still growing, it's very likely that CUDA is doing just-in-time compilation.

You can also increase the cache size to up to 4GB, to avoid recomputation:

```bash
export CUDA_CACHE_MAXSIZE=4294967296
```


## Q: What about GPU support on my Mac?

Recent PyTorch versions provide a new backend based on Apple’s Metal Performance Shaders (MPS).
The MPS backend enables GPU-accelerated training on the M1/M2 architecture.
Right now, there's no ARM build of PyTorch in JavaCPP and MPS is not enabled.
If you have an M1/M2 machine and want to help, check the umbrella [issue for macosx-aarch64 support](https://github.com/bytedeco/javacpp-presets/issues/1069).
10 changes: 5 additions & 5 deletions docs/installation.md
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ resolvers ++= Resolver.sonatypeOssRepos("snapshots")
libraryDependencies += Seq(
"dev.storch" %% "core" % "@VERSION@",
"org.bytedeco" % "pytorch-platform-gpu" % "@PYTORCH_VERSION@-@JAVACPP_VERSION@",
"org.bytedeco" % "cuda-platform-redist" % "11.8-8.6-@JAVACPP_VERSION@"
"org.bytedeco" % "cuda-platform-redist" % "@CUDA_VERSION@-@JAVACPP_VERSION@"
)
fork := true
```
Expand All @@ -165,7 +165,7 @@ fork := true
//> using repository "sonatype:snapshots"
//> using lib "dev.storch::core:@VERSION@"
//> using lib "org.bytedeco:pytorch-platform-gpu:@PYTORCH_VERSION@-@JAVACPP_VERSION@"
//> using lib "org.bytedeco:cuda-platform-redist:11.8-8.6-@JAVACPP_VERSION@"
//> using lib "org.bytedeco:cuda-platform-redist:@CUDA_VERSION@-@JAVACPP_VERSION@"
```

@:@
Expand All @@ -189,7 +189,7 @@ libraryDependencies += Seq(
"org.bytedeco" % "pytorch" % "@PYTORCH_VERSION@-@JAVACPP_VERSION@",
"org.bytedeco" % "pytorch" % "@PYTORCH_VERSION@-@JAVACPP_VERSION@" classifier "linux-x86_64-gpu",
"org.bytedeco" % "openblas" % "@OPENBLAS_VERSION@-@JAVACPP_VERSION@" classifier "linux-x86_64",
"org.bytedeco" % "cuda" % "11.8-8.6-@JAVACPP_VERSION@" classifier "linux-x86_64-redist"
"org.bytedeco" % "cuda" % "@CUDA_VERSION@-@JAVACPP_VERSION@" classifier "linux-x86_64-redist"
)
fork := true
```
Expand All @@ -202,7 +202,7 @@ fork := true
//> using lib "dev.storch::core:@VERSION@"
//> using lib "org.bytedeco:pytorch:@PYTORCH_VERSION@-@JAVACPP_VERSION@,classifier=linux-x86_64-gpu"
//> using lib "org.bytedeco:openblas:@OPENBLAS_VERSION@-@JAVACPP_VERSION@,classifier=linux-x86_64"
//> using lib "org.bytedeco:cuda:11.8-8.6-@JAVACPP_VERSION@,classifier=linux-x86_64-redist"
//> using lib "org.bytedeco:cuda:@CUDA_VERSION@-@JAVACPP_VERSION@,classifier=linux-x86_64-redist"
```

@:@
Expand All @@ -223,7 +223,7 @@ resolvers ++= Resolver.sonatypeOssRepos("snapshots")
libraryDependencies += Seq(
"dev.storch" %% "core" % "@VERSION@",
)
javaCppPresetLibs ++= Seq("pytorch-gpu" -> "@PYTORCH_VERSION@", "openblas" -> "@OPENBLAS_VERSION@", "cuda-redist" -> "11.8-8.6")
javaCppPresetLibs ++= Seq("pytorch-gpu" -> "@PYTORCH_VERSION@", "openblas" -> "@OPENBLAS_VERSION@", "cuda-redist" -> "@CUDA_VERSION@")
fork := true
```

Expand Down

0 comments on commit bfcaab4

Please sign in to comment.