Add image encoder onnx export #326

Open · wants to merge 7 commits into base: main

6 changes: 6 additions & 0 deletions README.md
@@ -82,6 +82,12 @@ python scripts/export_onnx_model.py --checkpoint <path/to/checkpoint> --model-ty

See the [example notebook](https://github.com/facebookresearch/segment-anything/blob/main/notebooks/onnx_model_example.ipynb) for details on how to combine image preprocessing via SAM's backbone with mask prediction using the ONNX model. It is recommended to use the latest stable version of PyTorch for ONNX export.

To export SAM's image encoder so that it can be run in any environment that supports ONNX Runtime, export the model with

```
python scripts/export_image_encoder.py --checkpoint <path/to/checkpoint> --model-type <model_type> --output <path/to/output>
```
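
The exported encoder can then be run on its own with ONNX Runtime. Below is a minimal sketch, not part of this PR, assuming the model was exported with `--use-preprocess` (so it accepts an unnormalized `HxWx3` float image whose longest side already matches the encoder's 1024-pixel input size) and was saved to a hypothetical `image_encoder.onnx`:

```
import numpy as np
import onnxruntime

session = onnxruntime.InferenceSession("image_encoder.onnx")  # hypothetical path
# Any RGB image as float32, HxWx3, values in 0-255, longest side <= 1024.
image = (np.random.rand(684, 1024, 3) * 255).astype(np.float32)
(embeddings,) = session.run(None, {"input_image": image})
print(embeddings.shape)  # (1, 256, 64, 64) for the default 1024-pixel encoders
```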

### Web demo

The `demo/` folder has a simple one page React app which shows how to run mask prediction with the exported ONNX model in a web browser with multithreading. Please see [`demo/README.md`](https://github.com/facebookresearch/segment-anything/blob/main/demo/README.md) for more details.
206 changes: 206 additions & 0 deletions scripts/export_image_encoder.py
@@ -0,0 +1,206 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
import numpy as np

from segment_anything import sam_model_registry
from segment_anything.utils.onnx import ImageEncoderOnnxModel

import os
import argparse
import warnings

try:
    import onnxruntime  # type: ignore

    onnxruntime_exists = True
except ImportError:
    onnxruntime_exists = False

parser = argparse.ArgumentParser(
    description="Export the SAM image encoder to an ONNX model."
)

parser.add_argument(
    "--checkpoint",
    type=str,
    required=True,
    help="The path to the SAM model checkpoint.",
)

parser.add_argument(
    "--output", type=str, required=True, help="The filename to save the ONNX model to."
)

parser.add_argument(
    "--model-type",
    type=str,
    required=True,
    help="In ['default', 'vit_h', 'vit_l', 'vit_b']. Which type of SAM model to export.",
)

parser.add_argument(
    "--use-preprocess",
    action="store_true",
    help="Whether to preprocess the image by resizing, standardizing, etc.",
)

parser.add_argument(
    "--opset",
    type=int,
    default=17,
    help="The ONNX opset version to use. Must be >=11",
)

parser.add_argument(
    "--quantize-out",
    type=str,
    default=None,
    help=(
        "If set, will quantize the model and save it with this name. "
        "Quantization is performed with quantize_dynamic from "
        "onnxruntime.quantization.quantize."
    ),
)

parser.add_argument(
    "--gelu-approximate",
    action="store_true",
    help=(
        "Replace GELU operations with approximations using tanh. Useful "
        "for some runtimes that have slow or unimplemented erf ops, used in GELU."
    ),
)


def run_export(
    model_type: str,
    checkpoint: str,
    output: str,
    use_preprocess: bool,
    opset: int,
    gelu_approximate: bool = False,
):
    print("Loading model...")
    sam = sam_model_registry[model_type](checkpoint=checkpoint)

    onnx_model = ImageEncoderOnnxModel(
        model=sam,
        use_preprocess=use_preprocess,
        pixel_mean=[123.675, 116.28, 103.53],
        pixel_std=[58.395, 57.12, 57.375],
    )

    if gelu_approximate:
        for n, m in onnx_model.named_modules():
            if isinstance(m, torch.nn.GELU):
                m.approximate = "tanh"

    image_size = sam.image_encoder.img_size
    if use_preprocess:
        # With preprocessing folded into the model, the input is an
        # unnormalized (H, W, 3) image and its spatial dims are dynamic.
        dummy_input = {
            "input_image": torch.randn((image_size, image_size, 3), dtype=torch.float)
        }
        dynamic_axes = {
            "input_image": {0: "image_height", 1: "image_width"},
        }
    else:
        # Otherwise the input must already be a preprocessed (1, 3, H, W)
        # batch at the encoder's fixed image size.
        dummy_input = {
            "input_image": torch.randn(
                (1, 3, image_size, image_size), dtype=torch.float
            )
        }
        dynamic_axes = None

    _ = onnx_model(**dummy_input)

    output_names = ["image_embeddings"]

    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
        warnings.filterwarnings("ignore", category=UserWarning)
        print(f"Exporting onnx model to {output}...")
        if model_type == "vit_h":
            # vit_h can exceed the 2GB protobuf limit, so export by path so
            # that the weights can be written as external data files
            # alongside the model.
            output_dir, output_file = os.path.split(output)
            os.makedirs(output_dir, mode=0o777, exist_ok=True)
            torch.onnx.export(
                onnx_model,
                tuple(dummy_input.values()),
                output,
                export_params=True,
                verbose=False,
                opset_version=opset,
                do_constant_folding=True,
                input_names=list(dummy_input.keys()),
                output_names=output_names,
                dynamic_axes=dynamic_axes,
            )
        else:
            with open(output, "wb") as f:
                torch.onnx.export(
                    onnx_model,
                    tuple(dummy_input.values()),
                    f,
                    export_params=True,
                    verbose=False,
                    opset_version=opset,
                    do_constant_folding=True,
                    input_names=list(dummy_input.keys()),
                    output_names=output_names,
                    dynamic_axes=dynamic_axes,
                )

    if onnxruntime_exists:
        ort_inputs = {k: to_numpy(v) for k, v in dummy_input.items()}
        providers = ["CPUExecutionProvider"]

        if model_type == "vit_h":
            # Register the external weight files as initializers and pass the
            # session options to the session so they are actually used.
            session_option = onnxruntime.SessionOptions()
            param_files = os.listdir(output_dir)
            param_files.remove(output_file)
            for layer in param_files:
                with open(os.path.join(output_dir, layer), "rb") as fp:
                    weights = np.frombuffer(fp.read(), dtype=np.float32)
                weights = onnxruntime.OrtValue.ortvalue_from_numpy(weights)
                session_option.add_initializer(layer, weights)
            ort_session = onnxruntime.InferenceSession(
                output, sess_options=session_option, providers=providers
            )
        else:
            ort_session = onnxruntime.InferenceSession(output, providers=providers)

        _ = ort_session.run(None, ort_inputs)
        print("Model has successfully been run with ONNXRuntime.")


def to_numpy(tensor):
    return tensor.cpu().numpy()


if __name__ == "__main__":
    args = parser.parse_args()
    run_export(
        model_type=args.model_type,
        checkpoint=args.checkpoint,
        output=args.output,
        use_preprocess=args.use_preprocess,
        opset=args.opset,
        gelu_approximate=args.gelu_approximate,
    )

    if args.quantize_out is not None:
        assert onnxruntime_exists, "onnxruntime is required to quantize the model."
        from onnxruntime.quantization import QuantType  # type: ignore
        from onnxruntime.quantization.quantize import quantize_dynamic  # type: ignore

        print(f"Quantizing model and writing to {args.quantize_out}...")
        quantize_dynamic(
            model_input=args.output,
            model_output=args.quantize_out,
            optimize_model=True,
            per_channel=False,
            reduce_range=False,
            weight_type=QuantType.QUInt8,
        )
        print("Done!")
48 changes: 47 additions & 1 deletion segment_anything/utils/onnx.py
@@ -8,7 +8,7 @@
import torch.nn as nn
from torch.nn import functional as F

-from typing import Tuple
+from typing import List, Tuple

from ..modeling import Sam
from .amg import calculate_stability_score
@@ -142,3 +142,49 @@ def forward(
            return upscaled_masks, scores, stability_scores, areas, masks

        return upscaled_masks, scores, masks


class ImageEncoderOnnxModel(nn.Module):
    """
    This model should not be called directly, but is used in ONNX export.
    It combines the image encoder of Sam with some functions modified to
    enable model tracing, and can optionally fold image preprocessing into
    the model. See the ONNX export script for details.
    """

    def __init__(
        self,
        model: Sam,
        use_preprocess: bool,
        pixel_mean: List[float] = [123.675, 116.28, 103.53],
        pixel_std: List[float] = [58.395, 57.12, 57.375],
    ):
        super().__init__()
        self.use_preprocess = use_preprocess
        self.pixel_mean = torch.tensor(pixel_mean, dtype=torch.float)
        self.pixel_std = torch.tensor(pixel_std, dtype=torch.float)
        self.image_encoder = model.image_encoder

    @torch.no_grad()
    def forward(self, input_image: torch.Tensor):
        if self.use_preprocess:
            input_image = self.preprocess(input_image)
        image_embeddings = self.image_encoder(input_image)
        return image_embeddings

    def preprocess(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize colors
        x = (x - self.pixel_mean) / self.pixel_std

        # Permute channels: (H, W, 3) -> (3, H, W)
        x = torch.permute(x, (2, 0, 1))

        # Zero-pad bottom/right to the encoder's square input size
        h, w = x.shape[-2:]
        padh = self.image_encoder.img_size - h
        padw = self.image_encoder.img_size - w
        x = F.pad(x, (0, padw, 0, padh))

        # Add a batch dimension: (3, H, W) -> (1, 3, H, W)
        x = torch.unsqueeze(x, 0)
        return x
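
For reference, `preprocess` turns an unnormalized `(H, W, 3)` image into the `(1, 3, 1024, 1024)` batch the default encoders expect: normalize, permute to channels-first, zero-pad bottom/right to the square input size, and add a batch dimension. A small sketch (the checkpoint path is hypothetical):

```
import torch
from segment_anything import sam_model_registry
from segment_anything.utils.onnx import ImageEncoderOnnxModel

sam = sam_model_registry["vit_b"](checkpoint="path/to/sam_vit_b.pth")  # hypothetical path
encoder = ImageEncoderOnnxModel(model=sam, use_preprocess=True)

image = torch.rand(600, 800, 3) * 255.0  # unnormalized HWC image, longest side <= 1024
x = encoder.preprocess(image)
print(x.shape)  # torch.Size([1, 3, 1024, 1024])
```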