Merge pull request #31 from sbrunk/cuda12
Upgrade to cuda 12 and add workaround for missing libcusolver
Showing 6 changed files with 75 additions and 7 deletions.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
/* | ||
* Copyright 2022 storch.dev | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package torch | ||
package internal | ||
|
||
import org.bytedeco.javacpp.Loader | ||
|
||
// This is a workaround for https://github.com/bytedeco/javacpp-presets/issues/1376 | ||
// TODO remove once the issue is fixed | ||
object LoadCusolver { | ||
try { | ||
val cusolver = Class.forName("org.bytedeco.cuda.global.cusolver") | ||
Loader.load(cusolver) | ||
} catch { | ||
case _: ClassNotFoundException => // ignore to avoid breaking CPU-only builds | ||
} | ||
|
||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
# Frequently Asked Questions | ||
|
||
## Q: I want to run operations on the GPU, but Storch seems to hang? | ||
|
||
Depending on your hardware, the CUDA version and capability settings, CUDA might need to do [just-in-time compilation](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#just-in-time-compilation) of your kernels, which | ||
can take a few minutes. The result is cached, so it should load faster on subsequent runs. | ||
|
||
If you're unsure, you can watch the size of the cache: | ||
|
||
```bash | ||
watch -d du -sm ~/.nv/ComputeCache | ||
``` | ||
If it's still growing, it's very likely that CUDA is doing just-in-time compilation. | ||
|
||
You can also increase the cache size, up to 4 GB, to avoid recompilation: | ||
|
||
```bash | ||
export CUDA_CACHE_MAXSIZE=4294967296 | ||
``` | ||
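
The cache check described above can also be scripted. The following is a minimal sketch that samples the size of the cache directory twice and reports whether it grew in between. The helper names `cache_size_kb` and `cache_is_growing` and the 5-second default interval are hypothetical, chosen only for illustration; they are not part of Storch or CUDA. The path `~/.nv/ComputeCache` is the one used in the `watch` command above.

```bash
# Hypothetical helpers; names and defaults are illustrative only.

# Print the size of a directory in kilobytes (0 if it does not exist).
cache_size_kb() {
  if [ -d "$1" ]; then
    du -sk "$1" | cut -f1
  else
    echo 0
  fi
}

# Sample the directory size twice, a given number of seconds apart
# (default: 5), and print "growing" or "stable".
cache_is_growing() {
  cig_dir="$1"
  cig_interval="${2:-5}"
  cig_before=$(cache_size_kb "$cig_dir")
  sleep "$cig_interval"
  cig_after=$(cache_size_kb "$cig_dir")
  if [ "$cig_after" -gt "$cig_before" ]; then
    echo "growing"
  else
    echo "stable"
  fi
}

# Example: cache_is_growing "$HOME/.nv/ComputeCache" 10
```

A result of `growing` only suggests that just-in-time compilation is still in progress; a single `stable` sample is not proof that it has finished, so it may be worth repeating the check with a longer interval.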
|
||
|
||
## Q: What about GPU support on my Mac? | ||
|
||
Recent PyTorch versions provide a new backend based on Apple’s Metal Performance Shaders (MPS). | ||
The MPS backend enables GPU-accelerated training on the M1/M2 architecture. | ||
Right now, there's no ARM build of PyTorch in JavaCPP and MPS is not enabled. | ||
If you have an M1/M2 machine and want to help, check the umbrella [issue for macosx-aarch64 support](https://github.com/bytedeco/javacpp-presets/issues/1069). |