Skip to content

Commit

Permalink
Fixes detection of CuPy installed with pre-built wheels (#1965)
Browse files Browse the repository at this point in the history
The CuPy library ships both a source distribution (`cupy`) as well as versions containing pre-built wheels (`cupy-cuda11x`, `cupy-cuda12x`, `cupy-rocm-5-0`, `cupy-rocm-4-3`). Use of `_is_package_available` to detect CuPy only works for the source distribution of CuPy and fails when using the pre-built wheels versions.

This is because the `_is_package_available` will always attempt to resolve version information (even if it's not required) and in doing so assumes that the _importable_ package name matches the _installed_ distribution name. While this is usually the case, it doesn't work for CuPy and several other libraries. ONNX Runtime for example might be installed as `onnxruntime` or `onnxruntime-gpu` and thus Optimum just uses `importlib.util.find_spec` to work around the same problem. This commit replicates the same solution for CuPy.
  • Loading branch information
tcsavage authored Sep 16, 2024
1 parent 26949f5 commit f1b708c
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions optimum/onnxruntime/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
"""Utility functions, classes and constants for ONNX Runtime."""

import importlib
import os
import re
from enum import Enum
Expand All @@ -31,7 +32,6 @@
import onnxruntime as ort

from ..exporters.onnx import OnnxConfig, OnnxConfigWithLoss
from ..utils.import_utils import _is_package_available


if TYPE_CHECKING:
Expand Down Expand Up @@ -91,9 +91,11 @@ def is_onnxruntime_training_available():

def is_cupy_available():
"""
Checks if onnxruntime-training is available.
Checks if CuPy is available.
"""
return _is_package_available("cupy")
# Don't use _is_package_available as it doesn't work with CuPy installed
# with `cupy-cuda*` and `cupy-rocm-*` package name (prebuilt wheels).
return importlib.util.find_spec("cupy") is not None


class ORTConfigManager:
Expand Down

0 comments on commit f1b708c

Please sign in to comment.