[Feature] Update pytorch2onnx #265

Merged
merged 3 commits into from
Apr 23, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
32 changes: 32 additions & 0 deletions docs/tools_scripts.md
@@ -47,3 +47,35 @@ python tools/publish_model.py work_dirs/example_exp/latest.pth example_model_202
```

The final output filename will be `example_model_20200202-{hash id}.pth`.

### Convert to ONNX (experimental)

We provide a script to convert a model to [ONNX](https://github.com/onnx/onnx) format. The converted model can be visualized by tools such as [Netron](https://github.com/lutzroeder/netron). We also support comparing the output results between the PyTorch and ONNX models.

```bash
python tools/pytorch2onnx.py \
${CFG_PATH} \
${CHECKPOINT_PATH} \
${MODEL_TYPE} \
${IMAGE_PATH} \
--trimap-path ${TRIMAP_PATH} \
--output-file ${OUTPUT_ONNX} \
--show \
--verify \
--dynamic-export
```

Description of arguments:

- `config` : The path of a model config file.
- `checkpoint` : The path of a model checkpoint file.
- `model_type` : The model type of the config file. Options: `inpainting`, `mattor`, `restorer`, `synthesizer`.
- `img_path` : The path of the input image file.
- `--trimap-path` : The path of the input trimap file, required by the `mattor` model.
- `--output-file` : The path of the output ONNX model. If not specified, it will be set to `tmp.onnx`.
- `--opset-version` : ONNX opset version. If not specified, it will be set to `11`.
- `--show` : Determines whether to print the architecture of the exported model (and, together with `--verify`, to visualize the compared outputs). If not specified, it will be set to `False`.
- `--verify` : Determines whether to verify the correctness of the exported model. If not specified, it will be set to `False`.
- `--dynamic-export` : Determines whether to export the ONNX model with dynamic input and output shapes. If not specified, it will be set to `False`.

**Note**: This tool is still experimental. Some custom operators are not supported for now, and only `mattor` and `restorer` models are currently supported.
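
After conversion, the exported model can be loaded with ONNX Runtime directly. The snippet below is a minimal sketch, not part of this PR: the input name `input` matches the name set by the export script, while the `tmp.onnx` path and the 1x4x256x256 dummy tensor (a 3-channel merged image concatenated with a 1-channel trimap, as for a `mattor` model) are assumptions for illustration.

```python
import numpy as np
import onnxruntime as rt

# load the exported model (path is the script's default output file)
sess = rt.InferenceSession('tmp.onnx')

# hypothetical mattor input: 3-channel merged image + 1-channel trimap
dummy = np.random.rand(1, 4, 256, 256).astype(np.float32)

# 'input' is the input name used by tools/pytorch2onnx.py
pred_alpha = sess.run(None, {'input': dummy})[0]
print(pred_alpha.shape)
```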
2 changes: 1 addition & 1 deletion setup.cfg
@@ -17,6 +17,6 @@ line_length = 79
multi_line_output = 0
known_standard_library = pkg_resources,setuptools
known_first_party = mmedit
known_third_party =PIL,cv2,lmdb,mmcv,numpy,onnx,onnxruntime,pymatting,pytest,scipy,titlecase,torch,torchvision
known_third_party =PIL,cv2,lmdb,matplotlib,mmcv,numpy,onnx,onnxruntime,pymatting,pytest,scipy,titlecase,torch,torchvision
no_lines_before = STDLIB,LOCALFOLDER
default_section = THIRDPARTY
120 changes: 100 additions & 20 deletions tools/pytorch2onnx.py
@@ -1,5 +1,6 @@
import argparse

import matplotlib.pyplot as plt
import mmcv
import numpy as np
import onnx
@@ -12,12 +13,28 @@
from mmedit.models import build_model


def show_result_pyplot(img, title='', block=True):
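"""Display an image (file path or array) with matplotlib; single-channel images are shown in grayscale."""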
if isinstance(img, str):
img = mmcv.imread(img, channel_order='rgb')
cmap = None
if len(img.shape) == 3 and img.shape[2] == 1:
img = img.squeeze(2)
cmap = 'gray'
plt.figure()
plt.imshow(img, cmap=cmap)
plt.title(title)
plt.tight_layout()
plt.show(block=block)


def pytorch2onnx(model,
input,
model_type,
opset_version=11,
show=False,
output_file='tmp.onnx',
verify=False):
verify=False,
dynamic_export=False):
"""Export Pytorch model to ONNX model and verify the outputs are same
between Pytorch and ONNX.

@@ -32,85 +49,143 @@ def pytorch2onnx(model,
Default: False.
"""
model.cpu().eval()
merged = input['merged'].unsqueeze(0)
trimap = input['trimap'].unsqueeze(0)
input = torch.cat((merged, trimap), 1)

if model_type == 'mattor':
merged = input['merged'].unsqueeze(0)
trimap = input['trimap'].unsqueeze(0)
data = torch.cat((merged, trimap), 1)
elif model_type == 'restorer':
data = input['lq'].unsqueeze(0)
model.forward = model.forward_dummy
# PyTorch 1.3 has bugs in some symbolic ops; work around them
# by registering replacement symbolics
register_extra_symbolics(opset_version)
dynamic_axes = None
if dynamic_export:
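# mark batch, height and width as dynamic axes so the exported graph accepts variable-sized inputs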
dynamic_axes = {
'input': {
0: 'batch',
2: 'height',
3: 'width'
},
'output': {
0: 'batch',
2: 'height',
3: 'width'
}
}
with torch.no_grad():
torch.onnx.export(
model,
input,
data,
output_file,
input_names=['cat_input'],
input_names=['input'],
output_names=['output'],
export_params=True,
keep_initializers_as_inputs=True,
keep_initializers_as_inputs=False,
verbose=show,
opset_version=opset_version)
opset_version=opset_version,
dynamic_axes=dynamic_axes)
print(f'Successfully exported ONNX model: {output_file}')
if verify:
# check by onnx
onnx_model = onnx.load(output_file)
onnx.checker.check_model(onnx_model)

if dynamic_export:
# scale image for dynamic shape test
data = torch.nn.functional.interpolate(data, scale_factor=1.1)

# concatenate the flipped image for a batch-size test
flip_data = data.flip(-1)
data = torch.cat((data, flip_data), 0)

# get the PyTorch output; only pred_alpha is of interest
pytorch_result = model(input)
with torch.no_grad():
pytorch_result = model(data)
if isinstance(pytorch_result, (tuple, list)):
pytorch_result = pytorch_result[0]
pytorch_result = pytorch_result.detach().numpy()
# get onnx output
sess = rt.InferenceSession(output_file)
onnx_result = sess.run(None, {
'cat_input': input.detach().numpy(),
'input': data.detach().numpy(),
})
# only the pred_alpha value is of interest
if isinstance(onnx_result, (tuple, list)):
onnx_result = onnx_result[0]

if show:
pytorch_visualize = pytorch_result[0].transpose(1, 2, 0)
pytorch_visualize = np.clip(pytorch_visualize, 0, 1)
onnx_visualize = onnx_result[0].transpose(1, 2, 0)
onnx_visualize = np.clip(onnx_visualize, 0, 1)

show_result_pyplot(pytorch_visualize, title='PyTorch', block=False)
show_result_pyplot(onnx_visualize, title='ONNXRuntime', block=True)

# check the numerical value
assert np.allclose(
pytorch_result,
onnx_result), 'The outputs are different between Pytorch and ONNX'
pytorch_result, onnx_result, rtol=1e-5,
atol=1e-5), 'The outputs are different between Pytorch and ONNX'
print('The numerical values are the same between Pytorch and ONNX')


def parse_args():
parser = argparse.ArgumentParser(description='Convert MMediting to ONNX')
parser.add_argument('config', help='test config file path')
parser.add_argument('checkpoint', help='checkpoint file')
parser.add_argument(
'model_type',
help='what kind of model the config belongs to.',
choices=['inpainting', 'mattor', 'restorer', 'synthesizer'])
parser.add_argument('img_path', help='path to input image file')
parser.add_argument('trimap_path', help='path to input trimap file')
parser.add_argument(
'--trimap-path',
default=None,
help='path to input trimap file, used in mattor model')
parser.add_argument('--show', action='store_true', help='show onnx graph')
parser.add_argument('--output-file', type=str, default='tmp.onnx')
parser.add_argument('--opset-version', type=int, default=11)
parser.add_argument(
'--verify',
action='store_true',
help='verify the onnx model output against pytorch output')
parser.add_argument(
'--dynamic-export',
action='store_true',
help='Whether to export onnx with dynamic axis.')
args = parser.parse_args()
return args


if __name__ == '__main__':
args = parse_args()
model_type = args.model_type

if model_type == 'mattor' and args.trimap_path is None:
raise ValueError('Please set `--trimap-path` to convert mattor model.')

assert args.opset_version == 11, 'MMEditing only supports opset 11 now'

config = mmcv.Config.fromfile(args.config)
config.model.pretrained = None
# ONNX does not support spectral norm
if hasattr(config.model.backbone.encoder, 'with_spectral_norm'):
config.model.backbone.encoder.with_spectral_norm = False
config.model.backbone.decoder.with_spectral_norm = False
config.test_cfg.metrics = None
if model_type == 'mattor':
if hasattr(config.model.backbone.encoder, 'with_spectral_norm'):
config.model.backbone.encoder.with_spectral_norm = False
config.model.backbone.decoder.with_spectral_norm = False
config.test_cfg.metrics = None

# build the model
model = build_model(config.model, test_cfg=config.test_cfg)
checkpoint = load_checkpoint(model, args.checkpoint, map_location='cpu')

# remove alpha from test_pipeline
keys_to_remove = ['alpha', 'ori_alpha']
if model_type == 'mattor':
keys_to_remove = ['alpha', 'ori_alpha']
elif model_type == 'restorer':
keys_to_remove = ['gt', 'gt_path']
for key in keys_to_remove:
for pipeline in list(config.test_pipeline):
if 'key' in pipeline and key == pipeline['key']:
@@ -124,14 +199,19 @@ def parse_args():
# build the data pipeline
test_pipeline = Compose(config.test_pipeline)
# prepare data
data = dict(merged_path=args.img_path, trimap_path=args.trimap_path)
if model_type == 'mattor':
data = dict(merged_path=args.img_path, trimap_path=args.trimap_path)
elif model_type == 'restorer':
data = dict(lq_path=args.img_path)
data = test_pipeline(data)

# convert model to onnx file
pytorch2onnx(
model,
data,
model_type,
opset_version=args.opset_version,
show=args.show,
output_file=args.output_file,
verify=args.verify)
verify=args.verify,
dynamic_export=args.dynamic_export)