This repository has been archived by the owner on Mar 12, 2024. It is now read-only.

Demo can be exported to ONNX but other pretrained models cannot #49

Open
ShubhamJain7 opened this issue Jun 4, 2020 · 9 comments
Labels: enhancement (New feature or request)

Comments

@ShubhamJain7 commented Jun 4, 2020

Instructions To Reproduce the Issue:

Run torch.onnx.export on the demo model provided here and on a model from Torch Hub. The demo model exports successfully, while the other models fail.

import torch

# detr_demo is the model defined in the DETR demo notebook;
# sample_input is assumed here (any float image batch of shape (N, 3, H, W))
sample_input = torch.randn(1, 3, 800, 800)

# works
torch.onnx.export(detr_demo, sample_input, 'detr_demo.onnx', opset_version=10)

# does not work
detr = torch.hub.load('facebookresearch/detr', 'detr_resnet50', pretrained=True)
detr.eval()
torch.onnx.export(detr, sample_input, 'detr.onnx', opset_version=10)

See the full code here.

The error log is as follows:

/usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:59: TracerWarning: Converting a tensor to a Python index might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
/usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:60: TracerWarning: Converting a tensor to a Python index might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
/usr/local/lib/python3.6/dist-packages/torch/tensor.py:467: RuntimeWarning: Iterating over a tensor might cause the trace to be incorrect. Passing a tensor of different shape won't change the number of iterations executed (and might lead to errors or silently give incorrect results).
  'incorrect results).', category=RuntimeWarning)
/root/.cache/torch/hub/facebookresearch_detr_master/util/misc.py:294: TracerWarning: Converting a tensor to a Python index might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  batch_shape = (len(tensor_list),) + max_size
/root/.cache/torch/hub/facebookresearch_detr_master/util/misc.py:301: TracerWarning: Converting a tensor to a Python index might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
/root/.cache/torch/hub/facebookresearch_detr_master/util/misc.py:302: TracerWarning: Converting a tensor to a Python index might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  m[: img.shape[1], :img.shape[2]] = False

---------------------------------------------------------------------------

IndexError                                Traceback (most recent call last)

<ipython-input-19-968e97398387> in <module>()
     11 
     12 torch.onnx.export(detr_demo, sample_input, 'detr_demo.onnx', opset_version = 10)
---> 13 torch.onnx.export(detr, sample_input, 'detr.onnx', opset_version = 10)

/usr/local/lib/python3.6/dist-packages/torch/onnx/__init__.py in export(model, args, f, export_params, verbose, training, input_names, output_names, aten, export_raw_ir, operator_export_type, opset_version, _retain_param_name, do_constant_folding, example_outputs, strip_doc_string, dynamic_axes, keep_initializers_as_inputs, custom_opsets, enable_onnx_checker, use_external_data_format)
    166                         do_constant_folding, example_outputs,
    167                         strip_doc_string, dynamic_axes, keep_initializers_as_inputs,
--> 168                         custom_opsets, enable_onnx_checker, use_external_data_format)
    169 
    170 

/usr/local/lib/python3.6/dist-packages/torch/onnx/utils.py in export(model, args, f, export_params, verbose, training, input_names, output_names, aten, export_raw_ir, operator_export_type, opset_version, _retain_param_name, do_constant_folding, example_outputs, strip_doc_string, dynamic_axes, keep_initializers_as_inputs, custom_opsets, enable_onnx_checker, use_external_data_format)
     67             dynamic_axes=dynamic_axes, keep_initializers_as_inputs=keep_initializers_as_inputs,
     68             custom_opsets=custom_opsets, enable_onnx_checker=enable_onnx_checker,
---> 69             use_external_data_format=use_external_data_format)
     70 
     71 

/usr/local/lib/python3.6/dist-packages/torch/onnx/utils.py in _export(model, args, f, export_params, verbose, training, input_names, output_names, operator_export_type, export_type, example_outputs, propagate, opset_version, _retain_param_name, do_constant_folding, strip_doc_string, dynamic_axes, keep_initializers_as_inputs, fixed_batch_size, custom_opsets, add_node_names, enable_onnx_checker, use_external_data_format)
    486                                                         example_outputs, propagate,
    487                                                         _retain_param_name, val_do_constant_folding,
--> 488                                                         fixed_batch_size=fixed_batch_size)
    489 
    490         # TODO: Don't allocate a in-memory string for the protobuf

/usr/local/lib/python3.6/dist-packages/torch/onnx/utils.py in _model_to_graph(model, args, verbose, training, input_names, output_names, operator_export_type, example_outputs, propagate, _retain_param_name, do_constant_folding, _disable_torch_constant_prop, fixed_batch_size)
    349     graph = _optimize_graph(graph, operator_export_type,
    350                             _disable_torch_constant_prop=_disable_torch_constant_prop,
--> 351                             fixed_batch_size=fixed_batch_size, params_dict=params_dict)
    352 
    353     if isinstance(model, torch.jit.ScriptModule) or isinstance(model, torch.jit.ScriptFunction):

/usr/local/lib/python3.6/dist-packages/torch/onnx/utils.py in _optimize_graph(graph, operator_export_type, _disable_torch_constant_prop, fixed_batch_size, params_dict)
    152         torch._C._jit_pass_erase_number_types(graph)
    153 
--> 154         graph = torch._C._jit_pass_onnx(graph, operator_export_type)
    155         torch._C._jit_pass_lint(graph)
    156 

/usr/local/lib/python3.6/dist-packages/torch/onnx/__init__.py in _run_symbolic_function(*args, **kwargs)
    197 def _run_symbolic_function(*args, **kwargs):
    198     from torch.onnx import utils
--> 199     return utils._run_symbolic_function(*args, **kwargs)
    200 
    201 

/usr/local/lib/python3.6/dist-packages/torch/onnx/utils.py in _run_symbolic_function(g, n, inputs, env, operator_export_type)
    738                                   .format(op_name, opset_version, op_name))
    739                 op_fn = sym_registry.get_registered_op(op_name, '', opset_version)
--> 740                 return op_fn(g, *inputs, **attrs)
    741 
    742         elif ns == "prim":

/usr/local/lib/python3.6/dist-packages/torch/onnx/symbolic_helper.py in wrapper(g, *args)
    127             assert len(arg_descriptors) >= len(args)
    128             args = [_parse_arg(arg, arg_desc) for arg, arg_desc in zip(args, arg_descriptors)]
--> 129             return fn(g, *args)
    130         # In Python 2 functools.wraps chokes on partially applied functions, so we need this as a workaround
    131         try:

/usr/local/lib/python3.6/dist-packages/torch/onnx/symbolic_opset9.py in ones(g, sizes, dtype, layout, device, pin_memory)
   1409         dtype = 6  # float
   1410     return g.op("ConstantOfShape", sizes,
-> 1411                 value_t=torch.tensor([1], dtype=sym_help.scalar_type_to_pytorch_type[dtype]))
   1412 
   1413 

IndexError: list index out of range
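
For context, the trace ends in the exporter's ones symbolic failing to map a dtype. Judging from the warnings above, which point at the padding-mask code in util/misc.py, one plausible trigger is the boolean mask that nested_tensor_from_tensor_list allocates with torch.ones (the mask that m[: img.shape[1], :img.shape[2]] = False later writes into). A minimal sketch of that pattern, purely as an illustration of the suspected failure mode:

import torch

# Hypothetical minimal module mimicking the suspected pattern from util/misc.py:
# allocate a bool padding mask with torch.ones and return it alongside the input.
class BoolMask(torch.nn.Module):
    def forward(self, x):
        b, c, h, w = x.shape
        mask = torch.ones((b, h, w), dtype=torch.bool, device=x.device)
        return x, mask

# On the PyTorch build used in this log, exporting a bool ones() at opset 10
# may hit the same IndexError in the ones symbolic.
torch.onnx.export(BoolMask(), torch.randn(1, 3, 32, 32), 'bool_mask.onnx',
                  opset_version=10)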

Expected behavior:

It should be possible to export a model from Torch Hub, just like the demo model.

Environment:

Google Colab

@fmassa (Contributor) commented Jun 4, 2020

Hi,

Thanks for your request.
Adding TorchScript support was on our TODO list, so I've just sent a PR making the Torch Hub models support TorchScript: #51

We don't currently have plans to add ONNX support, but we would welcome PRs making it work (as long as the code stays readable and simple to understand).

@fmassa added the enhancement (New feature or request) label on Jun 4, 2020
@pfeatherstone commented

ONNX export would be really useful, unless there is a way of converting TorchScript to ONNX.
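
For what it's worth, torch.onnx.export can also take a ScriptModule directly; on PyTorch versions of this era that path required the example_outputs argument (visible in the export signature in the traceback above). A sketch, reusing detr and sample_input from the repro, with no guarantee that DETR actually scripts cleanly at this point:

# assumes the TorchScript support from #51
scripted = torch.jit.script(detr)
# example_outputs is required when exporting a ScriptModule on these versions;
# it may need to be a flat tuple rather than DETR's output dict
example_out = detr(sample_input)
torch.onnx.export(scripted, sample_input, 'detr_scripted.onnx',
                  example_outputs=example_out, opset_version=10)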

@ShubhamJain7 (Author) commented

@fmassa
Thanks for the link to your PR!
I'll look into adding ONNX support.

@ShubhamJain7 changed the title from "Demo can be exported to ONNX but other pretrained models" to "Demo can be exported to ONNX but other pretrained models cannot" on Jun 25, 2020
@zhiqwang (Contributor) commented Jul 24, 2020

The model detr_resnet50 can be converted to ONNX, but its output is all NaN. I tested NestedTensor in another project, and its inference result in onnxruntime also differed from PyTorch's; after removing NestedTensor, the results matched PyTorch's. So I suspect the NestedTensor here is causing this problem. Are there any suggestions for fixing it?
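
For readers who want to reproduce this check, a minimal comparison along these lines should surface the NaN/mismatch (the file name, input shape, and the detr handle from the repro above are assumptions):

import numpy as np
import onnxruntime as ort
import torch

sample_input = torch.randn(1, 3, 800, 800)  # assumed input shape

# PyTorch reference; detr is the eval-mode torch.hub model from the repro
with torch.no_grad():
    torch_logits = detr(sample_input)['pred_logits'].numpy()

# onnxruntime result for the same input
sess = ort.InferenceSession('detr.onnx')
input_name = sess.get_inputs()[0].name
ort_outputs = sess.run(None, {input_name: sample_input.numpy()})

# the first output is assumed to correspond to pred_logits; an all-NaN array
# or a large deviation here reproduces the mismatch described above
np.testing.assert_allclose(torch_logits, ort_outputs[0], rtol=1e-3, atol=1e-4)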

@zhiqwang (Contributor) commented Jul 27, 2020

Hi all, I think I have fixed the problem I mentioned above. I looked at the code in torchvision's Faster R-CNN implementation and found that this function is aimed at resolving exactly this problem. My fix is here; I will test whether it resolves the ONNX inference problem in DETR.

@zhiqwang (Contributor) commented Jul 28, 2020

Hi @fmassa,

I made a quick fix to nested_tensor_from_tensor_list(), following torchvision's implementation, to make it traceable for ONNX export. The inference results in onnxruntime are now consistent with PyTorch's. Here is my modification, _onnx_nested_tensor_from_tensor_list().

Could you check my modification? If it is OK, I would like to submit a PR :)
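
For readers following along, the gist of the torchvision-style workaround is to compute the padded batch with traceable tensor ops (torch.max over stacked sizes, F.pad) instead of Python-int indexing and in-place slice copies. A sketch of the idea, assuming the linked modification follows torchvision closely (the real helper would wrap the results in a NestedTensor):

from typing import List

import torch
import torch.nn.functional as F

def _onnx_nested_tensor_from_tensor_list(tensor_list: List[torch.Tensor]):
    # per-dimension max size, computed with tensor ops so it stays symbolic
    # during ONNX export tracing (as_tensor also keeps this runnable eagerly)
    max_size = []
    for i in range(tensor_list[0].dim()):
        sizes = torch.stack([torch.as_tensor(img.shape[i]) for img in tensor_list])
        max_size.append(torch.max(sizes.to(torch.float32)).to(torch.int64))

    padded_imgs = []
    padded_masks = []
    for img in tensor_list:
        # pad up to max_size with F.pad instead of the in-place slice copy
        # (pad_img[...].copy_(img)) that the exporter cannot trace
        padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
        padded_imgs.append(F.pad(img, (0, padding[2], 0, padding[1], 0, padding[0])))

        # mask is False on valid pixels and True on padding, matching util/misc.py
        m = torch.zeros_like(img[0], dtype=torch.int, device=img.device)
        padded_mask = F.pad(m, (0, padding[2], 0, padding[1]), 'constant', 1)
        padded_masks.append(padded_mask.to(torch.bool))

    return torch.stack(padded_imgs), torch.stack(padded_masks)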

@fmassa (Contributor) commented Aug 1, 2020

@zhiqwang your implementation looks good. Can you send a PR?

@zhiqwang (Contributor) commented Aug 2, 2020

Hi @fmassa, of course, it's my pleasure.

@kopyl commented Jun 14, 2023

So is it supported now?
