[docs][pytorch] Add examples for compiling with external weights. #18658

Merged: 3 commits, Oct 3, 2024
Changes from 2 commits
126 changes: 126 additions & 0 deletions docs/website/docs/guides/ml-frameworks/pytorch.md
@@ -322,6 +322,132 @@ their values independently at runtime.
self.value = new_value
```

#### :octicons-code-16: Separating parameters during compilation

In practical scenarios, it is usually unfeasible to embed model parameters
in the IR.

Reviewer: nit: infeasible is more common.

Reviewer: nit: instead of "In practical scenarios", maybe put "For large models", since this is the practical scenario where this applies.

Author: Ok, will do.

`aot.externalize_module_parameters()` allows for the separation of
the parameters from the IR, but still encodes a symbolic relationship
between the parameters and the IR, so they can be loaded at runtime.

We define a `torch.nn.Module` that will be compiled, and save
random tensors to be loaded as parameters in a safetensors file.

```python
import torch

class LinearModule(torch.nn.Module):
def __init__(self, in_features, out_features):
super().__init__()
self.weight = torch.nn.Parameter(torch.randn(in_features, out_features))
self.bias = torch.nn.Parameter(torch.randn(out_features))

def forward(self, input):
return (input @ self.weight) + self.bias

# Construct LinearModule.
linear_module = LinearModule(4, 3)
```

A safetensors file can be created directly from a dictionary of randomly
generated torch tensors.

```python
# Create a params dictionary. Note that the keys here match
# LinearModule's attributes. The saved safetensors file will
# later be used from the command line.
from safetensors.torch import save_file

wt = torch.randn(4, 3)
bias = torch.randn(3)
params = {"weight": wt, "bias": bias}
save_file(params, "params.safetensors")
```
Reviewer: I'm finding this section misleading from the example above. The linear_module already has specific weight and bias parameters. Why are we storing new random ones?

Author: Just as an example! We might not always have the class that actually represents the model, but just some representation of the graph. This is a demonstration of how, in such a case, one can load the weights separately.

Author: @zjgarvey I changed the examples to use the module's parameters.
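
As an aside on the discussion above: when the `torch.nn.Module` is available,
the safetensors file can just as well be built from the module's own
parameters instead of fresh random tensors. A minimal sketch, assuming the
`LinearModule` instance defined earlier (its parameter names, "weight" and
"bias", match the keys used above):

```python
from safetensors.torch import save_file

# Save LinearModule's actual parameters rather than new random tensors.
# The keys ("weight", "bias") are the names the externalized IR will
# reference symbolically.
module_params = {
    name: param.detach().contiguous()
    for name, param in linear_module.named_parameters()
}
save_file(module_params, "params.safetensors")
```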


Exporting through `aot.export()` and compiling the produced output generates a binary.

```python
import numpy as np
import shark_turbine.aot as aot

# Externalize the model parameters. This removes weight
# tensors from the IR module, allowing them to be loaded
# at runtime. Symbolic references to these parameters are
# still retained in the IR.
aot.externalize_module_parameters(linear_module)

ext_arg = torch.randn(4)
exported_module = aot.export(linear_module, ext_arg)

# Compile the exported module to generate the binary.
binary = exported_module.compile(save_to=None)

# Save the input as an npy tensor for later reuse.
ext_arg_2 = ext_arg.numpy()
np.save("input.npy", ext_arg_2)
```

Reviewer: Where does this externalize the weights to? Does the file have to be called params.safetensors in the PWD or something?

Reviewer: Reading further on, it looks like you specify the params file when invoking iree-run-module. I suppose the symbolic reference is to the name of the parameter in the safetensors dictionary you saved earlier?

Author: No! It can be any path you specify. Here it is this file in PWD.

Author: Yes, and under the scope named "model".

Runtime invocation now requires loading the parameters as a separate module.
`iree.runtime` provides a convenience function for this,
`create_io_parameters_module()`.

```python
import iree.runtime as ireert

# To load the parameters, we define a ParameterIndex and add a
# buffer for each parameter.
idx = ireert.ParameterIndex()
idx.add_buffer("weight", wt.numpy().tobytes())
idx.add_buffer("bias", bias.numpy().tobytes())

# Create the runtime config and VM instance.

config = ireert.Config(driver_name="local-task")
instance = config.vm_instance

param_module = ireert.create_io_parameters_module(
instance, idx.create_provider(scope="model"),
)

# Load the VM modules. There are essentially two modules to load:
# one for the weights, and one for the main module.
# Ensure that the VMFB file is not already open or deleted before use.
vm_modules = ireert.load_vm_modules(
param_module,
ireert.create_hal_module(instance, config.device),
ireert.VmModule.copy_buffer(instance, binary.map_memory()),
config=config,
)

# vm_modules is a list of modules. The last module in the list is the one
# generated from the binary, so we use that to generate an output.
result = vm_modules[-1].main(ext_arg)
print(result.to_host())
```
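
As the review thread above notes, the symbolic references in the IR resolve
to entry names within whatever parameter index you supply at runtime. As an
alternative to rebuilding the index from in-memory tensors, here is a minimal
sketch (assuming the `params.safetensors` file saved earlier, the standard
safetensors `safe_open` API, and the `instance` created above) that fills the
`ParameterIndex` directly from the file:

```python
from safetensors import safe_open

# Build the index from the saved file. The keys ("weight", "bias") are the
# names the externalized IR references under the "model" scope.
idx = ireert.ParameterIndex()
with safe_open("params.safetensors", framework="pt") as f:
    for key in f.keys():
        idx.add_buffer(key, f.get_tensor(key).numpy().tobytes())

param_module = ireert.create_io_parameters_module(
    instance, idx.create_provider(scope="model"),
)
```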

It is also possible to save the VMFB binary to disk and then run
`iree-run-module` from the command line to generate outputs.

```python
# When save_to is not None, the binary is saved to the given path,
# and a None value is returned.
binary = exported_module.compile(save_to="compiled_module.vmfb")
```

The stored safetensors file, the input tensor, and the VMFB can now be passed
to IREE on the command line.

```bash
iree-run-module --module=compiled_module.vmfb --parameters=model=params.safetensors \
--input=@input.npy
```

Note here that the `--parameters` flag has `model=` following it immediately.
This simply specifies the scope of the parameters, and is reflected in the
compiled IR.
Reviewer: I'm not sure I understand this. In what situations are parameters not model-wide?

Author: model is a scope here. It could actually be anything. It is even possible that the biases could be under a scope called "b", and the weights under a scope called "a". In that case, they would be in different indices, and model= would then become b= and a= respectively.
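
To illustrate, here is a hypothetical sketch of loading parameters under two
separate scopes, "a" for the weights and "b" for the biases. It reuses
`ireert`, `instance`, `wt`, and `bias` from the runtime example above, and it
assumes both that the IR was compiled with its parameters externalized under
those scope names (the example in this guide uses the single scope "model")
and that `create_io_parameters_module()` accepts multiple providers:

```python
# Hypothetical: one ParameterIndex per scope.
idx_a = ireert.ParameterIndex()
idx_a.add_buffer("weight", wt.numpy().tobytes())

idx_b = ireert.ParameterIndex()
idx_b.add_buffer("bias", bias.numpy().tobytes())

# A single parameters module built from two providers, each with its own
# scope name. The rest of the runtime setup is unchanged.
param_module = ireert.create_io_parameters_module(
    instance,
    idx_a.create_provider(scope="a"),
    idx_b.create_provider(scope="b"),
)
```

On the command line, the equivalent would presumably be passing the
`--parameters` flag once per scope, e.g.
`--parameters=a=weights.safetensors --parameters=b=biases.safetensors`.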


#### :octicons-code-16: Samples

| Code samples | |
Expand Down
Loading