[docs][pytorch] Add examples for compiling with external weights. (#18658)

Progress on #18564.

Adds examples to the PyTorch guide, showing how to externalize module
parameters, and load them at runtime, both through command line
(`iree-run-module`) and through the iree-runtime Python API (using
`ParameterIndex`).
vinayakdsci authored Oct 3, 2024
1 parent 206c1f2 commit 718b4fd
Showing 1 changed file with 121 additions and 0 deletions: docs/website/docs/guides/ml-frameworks/pytorch.md
@@ -321,6 +321,127 @@ their values independently at runtime.
self.value = new_value
```

#### :octicons-file-symlink-file-16: Using external parameters

Model parameters can be stored in standalone files that are managed and
loaded separately from model compute graphs. See the
[Parameters guide](../parameters.md) for more general information about
parameters in IREE.

When using iree-turbine, the `aot.externalize_module_parameters()` function
separates parameters from program modules and encodes a symbolic relationship
between them so they can be loaded at runtime.

We use [Safetensors](https://huggingface.co/docs/safetensors/) here to store
the model's parameters on disk so that they can be loaded at runtime.

```python
import torch
from safetensors.torch import save_file
import numpy as np
import shark_turbine.aot as aot

class LinearModule(torch.nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(in_features, out_features))
        self.bias = torch.nn.Parameter(torch.randn(out_features))

    def forward(self, input):
        return (input @ self.weight) + self.bias

linear_module = LinearModule(4, 3)

# Create a params dictionary. Note that the keys here match LinearModule's
# attribute names. The saved safetensors file will be passed to
# `iree-run-module` on the command line.
wt = linear_module.weight.t().contiguous()
bias = linear_module.bias.t().contiguous()
params = {"weight": wt, "bias": bias}
save_file(params, "params.safetensors")

# Externalize the model parameters. This removes weight tensors from the IR
# module, allowing them to be loaded at runtime. Symbolic references to these
# parameters are still retained in the IR.
aot.externalize_module_parameters(linear_module)

input = torch.randn(4)
exported_module = aot.export(linear_module, input)

# Compile the exported module to generate the binary. When `save_to` is not
# None, the binary is written to the given path. Here we pass None so that
# the binary is returned and kept in a variable.
binary = exported_module.compile(save_to=None)

# Save the input as an npy tensor so that it can be passed to
# `iree-run-module` on the command line.
input_np = input.numpy()
np.save("input.npy", input_np)
```
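
As a quick sanity check, the saved file can be read back with safetensors'
standard `load_file` helper and compared against the tensors we wrote (a
minimal sketch, reusing the `wt` and `bias` variables from above):

```python
from safetensors.torch import load_file

# Read the parameters back and confirm they round-trip unchanged.
restored = load_file("params.safetensors")
assert torch.equal(restored["weight"], wt)
assert torch.equal(restored["bias"], bias)
```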

=== "Python runtime"

Runtime invocation now requires loading the parameters as a separate module.
To get the parameters as a module, iree.runtime provides a convenient method,
called `create_io_parameters_module()`.

    ```python
    import iree.runtime as ireert

    # Build a ParameterIndex and add an entry for each parameter. The keys
    # must match the parameter names recorded in the compiled module.
    idx = ireert.ParameterIndex()
    idx.add_buffer("weight", wt.detach().numpy().tobytes())
    idx.add_buffer("bias", bias.detach().numpy().tobytes())

    # Create the runtime config and get its VM instance.
    config = ireert.Config(driver_name="local-task")
    instance = config.vm_instance

    # Wrap the parameter index in a module, under the "model" scope.
    param_module = ireert.create_io_parameters_module(
        instance, idx.create_provider(scope="model"),
    )

    # Load the VM modules: the parameter module, the HAL module, and the main
    # module generated from the binary. Ensure that the binary stays alive
    # (not closed or deleted) while its memory is mapped.
    vm_modules = ireert.load_vm_modules(
        param_module,
        ireert.create_hal_module(instance, config.device),
        ireert.VmModule.copy_buffer(instance, binary.map_memory()),
        config=config,
    )

    # vm_modules is a list of modules. The last module in the list is the one
    # generated from the binary, so we use it to run the program.
    result = vm_modules[-1].main(input)
    print(result.to_host())
    ```
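
    Instead of adding raw buffers one by one, the index can also be populated
    from the saved file. This is a sketch assuming the runtime build supports
    parsing `.safetensors` files through `ParameterIndex.load()`:

    ```python
    # Alternative: populate the index from the file on disk (assumes
    # ParameterIndex.load() can parse .safetensors files in this build).
    idx = ireert.ParameterIndex()
    idx.load("params.safetensors")
    ```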

=== "Command line tools"

It is also possible to save the VMFB binary to disk, then call `iree-run-module`
through the command line to generate outputs.

    ```python
    # When save_to is not None, the binary is saved to the given path,
    # and a None value is returned.
    binary = exported_module.compile(save_to="compiled_module.vmfb")
    ```

    The stored safetensors file, the input tensor, and the VMFB can now be
    passed to IREE through the command line.

    ```bash
    iree-run-module --module=compiled_module.vmfb \
        --parameters=model=params.safetensors \
        --input=@input.npy
    ```

    Note that the `--parameters` flag is immediately followed by `model=`.
    This specifies the scope of the parameters and must match the scope
    recorded in the compiled module (`"model"` in the example above).
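
    The same invocation can be reproduced from Python against the on-disk
    artifacts. This is a sketch, again assuming `ParameterIndex.load()` can
    parse `.safetensors` files; `VmModule.mmap()` maps the saved VMFB from
    disk without copying it:

    ```python
    import numpy as np
    import iree.runtime as ireert

    config = ireert.Config(driver_name="local-task")
    instance = config.vm_instance

    # Load the parameters from the same file the CLI consumes, under the
    # same "model" scope.
    idx = ireert.ParameterIndex()
    idx.load("params.safetensors")
    param_module = ireert.create_io_parameters_module(
        instance, idx.create_provider(scope="model"),
    )

    # Map the saved VMFB from disk and run it on the saved input.
    vm_modules = ireert.load_vm_modules(
        param_module,
        ireert.create_hal_module(instance, config.device),
        ireert.VmModule.mmap(instance, "compiled_module.vmfb"),
        config=config,
    )
    print(vm_modules[-1].main(np.load("input.npy")).to_host())
    ```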

#### :octicons-code-16: Samples

| Code samples | |
