
#### :octicons-file-symlink-file-16: Using external parameters

Model parameters can be stored in standalone files that can be loaded
efficiently, separately from model compute graphs. See the
[Parameters guide](../parameters.md) for more general information about
parameters in IREE.

When using iree-turbine, the `aot.externalize_module_parameters()` function
separates parameters from program modules and encodes a symbolic relationship
between them so they can be loaded at runtime.

We define a `torch.nn.Module` that will be compiled, and use
[Safetensors](https://huggingface.co/docs/safetensors/) to store the model's
parameters on disk so that they can be loaded later at runtime.

```python
import torch
from safetensors.torch import save_file
import numpy as np
import shark_turbine.aot as aot

class LinearModule(torch.nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(in_features, out_features))
        self.bias = torch.nn.Parameter(torch.randn(out_features))

    def forward(self, input):
        return (input @ self.weight) + self.bias

# Construct LinearModule.
linear_module = LinearModule(4, 3)
```
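
As a quick, optional sanity check, the module can be run eagerly first to
confirm that the shapes line up:

```python
# A (4,)-shaped input through LinearModule(4, 3) yields a (3,)-shaped output.
print(linear_module(torch.randn(4)).shape)  # torch.Size([3])
```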

A safetensors file can be created directly from a dictionary of the module's
parameter tensors.

```python
# Create a params dictionary. Note that the keys here match LinearModule's
# attributes. We will use the saved safetensors file from the command line
# later.
wt = linear_module.weight.t().contiguous()
bias = linear_module.bias.t().contiguous()
params = { "weight": wt, "bias": bias }
save_file(params, "params.safetensors")
```
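
Optionally, the file can be reloaded to verify that the tensors round-trip
as expected:

```python
from safetensors.torch import load_file

# Reload the saved file and compare against the tensors saved above.
loaded = load_file("params.safetensors")
assert torch.equal(loaded["weight"], wt)
assert torch.equal(loaded["bias"], bias)
```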

Exporting through `aot.export()` and compiling the produced output generates a
binary in IREE's VMFB (VM FlatBuffer) format.

```python
# Externalize the model parameters. This removes weight tensors from the IR
# module, allowing them to be loaded at runtime. Symbolic references to these
# parameters are still retained in the IR.
aot.externalize_module_parameters(linear_module)

input = torch.randn(4)
exported_module = aot.export(linear_module, input)

# Compile the exported module to generate the binary. When `save_to` is
# not None, the binary will be stored at the path passed to `save_to`.
# Here we pass None so that the binary is stored in a Python variable.
binary = exported_module.compile(save_to=None)

# Save the input as an npy tensor, so that it can be passed in through the
# command line to `iree-run-module`.
input_np = input.numpy()
np.save("input.npy", input_np)
```
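
To confirm that the parameters were externalized, the exported MLIR can be
printed and inspected for symbolic parameter references. This is a sketch
assuming the `print_readable()` helper on the export output, as used in other
iree-turbine examples; it is best called right after `aot.export()`, before
compiling:

```python
# Externalized parameters appear as symbolic references in the printed IR
# instead of inlined constant weights.
exported_module.print_readable()
```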

=== "Python runtime"

    Runtime invocation now requires loading the parameters as a separate
    module. To get the parameters as a module, `iree.runtime` provides the
    convenience method `create_io_parameters_module()`.

    ```python
    import iree.runtime as ireert

    # To load the parameters, we define a ParameterIndex entry for each
    # parameter.
    idx = ireert.ParameterIndex()
    idx.add_buffer("weight", wt.detach().numpy().tobytes())
    idx.add_buffer("bias", bias.detach().numpy().tobytes())

    # Create the runtime config and VM instance.
    config = ireert.Config(driver_name="local-task")
    instance = config.vm_instance

    param_module = ireert.create_io_parameters_module(
        instance, idx.create_provider(scope="model"),
    )

    # Load the runtime. There are essentially two modules to load: one for the
    # weights and one for the main module. Ensure that the VMFB file is not
    # already open or deleted before use.
    vm_modules = ireert.load_vm_modules(
        param_module,
        ireert.create_hal_module(instance, config.device),
        ireert.VmModule.copy_buffer(instance, binary.map_memory()),
        config=config,
    )

    # vm_modules is a list of modules. The last module in the list is the one
    # generated from the binary, so we use that to generate an output.
    result = vm_modules[-1].main(input)
    print(result.to_host())
    ```

=== "Command line tools"

    It is also possible to save the VMFB binary to disk, then call
    `iree-run-module` from the command line to generate outputs.

    ```python
    # When save_to is not None, the binary is saved to the given path,
    # and a None value is returned.
    binary = exported_module.compile(save_to="compiled_module.vmfb")
    ```

    The stored safetensors file, the input tensor, and the VMFB can now be
    passed to IREE through the command line.

    ```bash
    iree-run-module --module=compiled_module.vmfb --parameters=model=params.safetensors \
      --input=@input.npy
    ```

    Note here that the `--parameters` flag has `model=` immediately following
    it. This specifies the scope of the parameters, which is reflected in the
    compiled module.
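
    For deployment, the safetensors file can also be converted into IREE's
    Parameter Archive format (IRPA) using the `iree-convert-parameters` tool
    described in the [Parameters guide](../parameters.md). A sketch, assuming
    the flag spellings from that guide:

    ```bash
    # Convert the safetensors parameters into a single .irpa archive, then
    # run the compiled module against the converted parameters.
    iree-convert-parameters --parameters=params.safetensors --output=params.irpa
    iree-run-module --module=compiled_module.vmfb --parameters=model=params.irpa \
      --input=@input.npy
    ```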

#### :octicons-code-16: Samples
