Skip to content

Commit

Permalink
Add env variable to disable macOS MPS devices, closes #592
Browse files Browse the repository at this point in the history
  • Loading branch information
davidmezzetti committed Nov 1, 2023
1 parent df3b810 commit e7552a6
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 2 deletions.
1 change: 1 addition & 0 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ jobs:
- name: Install dependencies - macOS
run: |
echo "OMP_NUM_THREADS=1" >> $GITHUB_ENV
echo "PYTORCH_MPS_DISABLE=1" >> $GITHUB_ENV
if: matrix.os == 'macos-latest'

- name: Install dependencies - Windows
Expand Down
15 changes: 13 additions & 2 deletions src/python/txtai/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def reference(deviceid):
else f"cuda:{deviceid}"
if torch.cuda.is_available()
else "mps"
if torch.backends.mps.is_available()
if Models.hasmpsdevice()
else Models.finddevice()
)

Expand All @@ -145,7 +145,18 @@ def hasaccelerator():
True if an accelerator device is available, False otherwise
"""

return torch.cuda.is_available() or torch.backends.mps.is_available() or bool(Models.finddevice())
return torch.cuda.is_available() or Models.hasmpsdevice() or bool(Models.finddevice())

@staticmethod
def hasmpsdevice():
"""
Checks if there is a MPS device available.
Returns:
True if a MPS device is available, False otherwise
"""

return os.environ.get("PYTORCH_MPS_DISABLE") != "1" and torch.backends.mps.is_available()

@staticmethod
def finddevice():
Expand Down

0 comments on commit e7552a6

Please sign in to comment.