Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve device support and add support for Apple Silicon chipset (mps) #34

Merged

Conversation

gonlairo
Copy link
Contributor

@gonlairo gonlairo commented Jun 30, 2023

This pull request improves device support for sklearn API models, enabling transfer between CPU, CUDA and Apple Silicon GPU environments.

Fix https://github.com/AdaptiveMotorControlLab/CEBRA-dev/pull/631

@stes stes force-pushed the rodrigo/move-model-to-device branch from 663e50d to e1e6848 Compare July 3, 2023 21:13
@stes stes self-requested a review July 3, 2023 21:15
@stes stes added the enhancement New feature or request label Jul 3, 2023
@stes stes force-pushed the rodrigo/move-model-to-device branch from 11c6acf to 42965ea Compare July 5, 2023 23:23
@rob-the-bot
Copy link

Based on the branch from this pull request, I've also extended the device support to more generic GPU devices with torch-directml. It's still under active development but already has most of functionalities. I've tested on Windows 10 22H2 with AMD RX 5500 XT, the Demo_decoding notebook works (see image below).

Is the dev team also interested in making CEBRA available to AMD/Intel GPU users? If yes I'm happy to include my changes to this pull request. The changes to the CEBRA codebase is minimal, and torch-directml needs to be imported during runtime.

image

@stes stes changed the title Improved Device Support for Models Using sklearn API Improve device support and add support for Apple Silicon chipset (mps) Jul 13, 2023
Copy link
Member

@stes stes left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's add more tests to verify the functionality.

cebra/integrations/sklearn/utils.py Show resolved Hide resolved
@gonlairo gonlairo requested a review from stes July 14, 2023 14:33
cebra/integrations/sklearn/utils.py Show resolved Hide resolved
@stes stes requested a review from MMathisLab July 17, 2023 12:31
Copy link
Member

@MMathisLab MMathisLab left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lgtm - didn't directly test myself though :)

@gonlairo
Copy link
Contributor Author

test with CUDA:
Screen Shot 2023-07-17 at 7 17 29 PM

@stes stes merged commit 00601fb into AdaptiveMotorControlLab:main Jul 17, 2023
8 checks passed
@stes stes deleted the rodrigo/move-model-to-device branch July 17, 2023 21:25
This pull request was closed.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA signed enhancement New feature or request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants