diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 43bef5257..572835481 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -40,6 +40,7 @@ jobs: - name: Install dependencies - macOS run: | echo "OMP_NUM_THREADS=1" >> $GITHUB_ENV + echo "PYTORCH_MPS_DISABLE=1" >> $GITHUB_ENV if: matrix.os == 'macos-latest' - name: Install dependencies - Windows diff --git a/src/python/txtai/models/models.py b/src/python/txtai/models/models.py index 267884b30..c3f50c593 100644 --- a/src/python/txtai/models/models.py +++ b/src/python/txtai/models/models.py @@ -132,7 +132,7 @@ def reference(deviceid): else f"cuda:{deviceid}" if torch.cuda.is_available() else "mps" - if torch.backends.mps.is_available() + if Models.hasmpsdevice() else Models.finddevice() ) @@ -145,7 +145,18 @@ def hasaccelerator(): True if an accelerator device is available, False otherwise """ - return torch.cuda.is_available() or torch.backends.mps.is_available() or bool(Models.finddevice()) + return torch.cuda.is_available() or Models.hasmpsdevice() or bool(Models.finddevice()) + + @staticmethod + def hasmpsdevice(): + """ + Checks if there is a MPS device available. + + Returns: + True if a MPS device is available, False otherwise + """ + + return os.environ.get("PYTORCH_MPS_DISABLE") != "1" and torch.backends.mps.is_available() @staticmethod def finddevice():