
[docs][pytorch] Add examples for compiling with external weights. #18658

Merged: 3 commits into iree-org:main on Oct 3, 2024

Conversation

@vinayakdsci (Contributor) commented on Oct 1, 2024:

Progress on #18564.

Adds examples to the PyTorch guide showing how to externalize module parameters and load them at runtime, both through the command line (`iree-run-module`) and through the `iree-runtime` Python API (using `ParameterIndex`).

@ScottTodd added the documentation ✏️ (Improvements or additions to documentation) and integrations/pytorch (PyTorch integration work) labels on Oct 1, 2024.
@zjgarvey (Contributor) left a comment:

Nice, I'm getting a sense for how this external parameter thing is working. Some comments based on a fairly naive read-through.

@@ -322,6 +322,132 @@ their values independently at runtime.
self.value = new_value
```

#### :octicons-code-16: Separating parameters during compilation

In practical scenarios, it is usually unfeasible to embed model parameters
@zjgarvey (Contributor):

nit: infeasible is more common.

nit: instead of "In practical scenarios", maybe put "For large models", since this is the practical scenario where this applies.

@vinayakdsci (Author):

Ok, will do.

Comment on lines 352 to 365
Safetensors can be created directly from a parameters dictionary created from
randomly generated torch tensors.

```python
# Create a params dictionary. Note that the keys here match
# LinearModule's attributes. We will later use the saved
# safetensors file from the command line.
import torch
from safetensors.torch import save_file

wt = torch.randn(4, 3)
bias = torch.randn(3)
params = {"weight": wt, "bias": bias}
save_file(params, "params.safetensors")
```
@zjgarvey (Contributor):

I'm finding this section misleading from the example above. The linear_module already has specific weight and bias parameters. Why are we storing new random ones?

@vinayakdsci (Author):

Just as an example! We might not always have the class that actually represents the model, but just some representation of the graph. This is a demonstration of how, in such a case, one can load the weights separately.

@vinayakdsci (Author):

@zjgarvey I changed the examples to use the module's parameters.

```python
# tensors from the IR module, allowing them to be loaded
# at runtime. Symbolic references to these parameters are
# still retained in the IR.
aot.externalize_module_parameters(linear_module)
```
@zjgarvey (Contributor):

Where does this externalize the weights to? Does the file have to be called params.safetensors in the PWD or something?

@zjgarvey (Contributor):

Reading further on, it looks like you specify the params file when invoking iree-run-module. I suppose the symbolic reference is to the name of the parameter in the safetensors dictionary you saved earlier?

@vinayakdsci (Author):

No! It can be any path you specify. Here it is this file in the current working directory.

@vinayakdsci (Author):

> Reading further on, it looks like you specify the params file when invoking `iree-run-module`. I suppose the symbolic reference is to the name of the parameter in the safetensors dictionary you saved earlier?

Yes, and under the scope named "model".
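The resolution described in this thread can be illustrated with a small plain-Python sketch. This is not the `iree.runtime` API (whose real `ParameterIndex` has a different interface); the class and method names below are hypothetical, purely to show how a symbolic reference such as `model.weight` resolves against a scoped index loaded at runtime.

```python
# Conceptual sketch only: the compiled module retains symbolic
# references like ("model", "weight"); at runtime, an index built from
# params.safetensors supplies the actual tensors for those references.
class ScopedParameterIndex:
    def __init__(self):
        self._scopes = {}  # scope name -> {parameter name -> value}

    def load(self, scope, params):
        """Register a dictionary of parameters under a named scope."""
        self._scopes.setdefault(scope, {}).update(params)

    def resolve(self, scope, name):
        """Look up a symbolic reference such as model.weight."""
        try:
            return self._scopes[scope][name]
        except KeyError:
            raise KeyError(f"unresolved parameter reference {scope}.{name}")

index = ScopedParameterIndex()
index.load("model", {"weight": [[1.0, 2.0]], "bias": [0.5]})
print(index.resolve("model", "bias"))  # prints [0.5]
```

A reference only fails to resolve if either the scope or the parameter name is missing from the index, which mirrors the failure mode discussed above when the wrong scope is passed on the command line.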

(Two resolved review threads on docs/website/docs/guides/ml-frameworks/pytorch.md, now outdated.)
Comment on lines 447 to 449
Note here that the `--parameters` flag has `model=` following it immediately.
This simply specifies the scope of the parameters, and is reflected in the
compiled IR.
@zjgarvey (Contributor):

I'm not sure I understand this. In what situations are parameters not model wide?

@vinayakdsci (Author):

`model` is a scope here. It could actually be anything. It is even possible that the biases could be under a scope called "b" and the weights under a scope called "a". In that case, they would be in different indices, and `model=` would then become `b=` and `a=` respectively.
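The `scope=path` pairing described above can be sketched in plain Python. This is an illustration only, not IREE's actual flag parser, and the fallback scope name `"model"` used here is an assumption made for the example:

```python
# Illustrative sketch of splitting a --parameters style argument
# ("scope=path") into a scope name and a file path. Not IREE's real
# parser; the default scope name below is an assumption for this example.
def parse_parameters_flag(value, default_scope="model"):
    scope, sep, path = value.partition("=")
    if not sep:
        # No explicit scope given; fall back to a default scope name.
        return default_scope, value
    return scope, path

# Each flag contributes one (scope, path) entry, so biases and weights
# could live in different files under different scopes:
print(parse_parameters_flag("b=biases.safetensors"))   # ('b', 'biases.safetensors')
print(parse_parameters_flag("a=weights.safetensors"))  # ('a', 'weights.safetensors')
```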

@ScottTodd (Member) left a comment:

Sending my initial comments since a few already overlap with Zach's :)

(Eight resolved inline comments on docs/website/docs/guides/ml-frameworks/pytorch.md, now outdated.)
@stellaraccident (Collaborator):

Thanks for doing this. Once landed, can someone make sure to add a test to iree-turbine which does this exact thing?

@vinayakdsci (Author):

@stellaraccident Sure! I'd be happy to do that.

@ScottTodd (Member) left a comment:

Great! Thanks!

@zjgarvey (Contributor) left a comment:

Nice, this looks good to me. Thanks for the changes!

@vinayakdsci merged commit 718b4fd into iree-org:main on Oct 3, 2024. 24 checks passed.