Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for device in safetensors.torch.load_model #449

Merged
merged 4 commits into from
Apr 15, 2024
Merged

Conversation

Wauplin
Copy link
Contributor

@Wauplin Wauplin commented Mar 5, 2024

safetensors.torch.load_file has a "device" parameter to load the tensors directly to the correct device. This PR adds support for this parameter in safetensors.torch.load_model too.

(also fix device type -see #449 (comment))

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@Wauplin Wauplin changed the title Fix typo Add support for device in safetensors.torch.load_model Mar 5, 2024
@Wauplin Wauplin requested review from Narsil and mishig25 March 5, 2024 10:49
When false, the function simply returns missing and unexpected names.
device (`Dict[str, any]`, *optional*, defaults to `cpu`):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
device (`Dict[str, any]`, *optional*, defaults to `cpu`):
device (`str`, *optional*, defaults to `cpu`):

is it dict or str ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I took the exact same definition as in load_file docstring:

device (`Dict[str, any]`, *optional*, defaults to `cpu`):

Copy link
Collaborator

@mishig25 mishig25 Mar 5, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

defaults to cpu

I see. I guess it is either str or Union[Dict[str, any], str]

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Made the change in d723ff7 :) (and used the Union[Dict[str, any], str] type).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Go down to the rust code here:

#[derive(Debug, Clone, PartialEq, Eq)]
enum Device {
    Cpu,
    Cuda(usize),
    Mps,
    Npu(usize),
    Xpu(usize),
}

Copy link
Collaborator

@mishig25 mishig25 Mar 7, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why is it Union[Dict[str, any], str] rather than str ?

Doesn't the enum mean that value is str rather than dict?

enum Device {
    Cpu,
    Cuda(usize),
    Mps,
    Npu(usize),
    Xpu(usize),
}

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes sorry so I went back to the source code and it looks like the expected type should be Union[str, int]. Here is the code:

I have updated the types accordingly in 3bad1e2. Let's wait for @Narsil's return just to be sure.

@github-actions github-actions bot added the Stale label Apr 7, 2024
@github-actions github-actions bot closed this Apr 13, 2024
@Wauplin Wauplin reopened this Apr 14, 2024
@github-actions github-actions bot removed the Stale label Apr 15, 2024
Copy link
Collaborator

@Narsil Narsil left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants