ONNX export with GPU (--device opt) is not working #4159

Closed
SamSamhuns opened this issue Jul 26, 2021 · 9 comments · Fixed by #5110
Labels
bug Something isn't working Stale

@SamSamhuns
Contributor

🐛 Bug

I get an "Expected all tensors to be on the same device" error when running python export.py --weight yolov5m.pt --include onnx --device 0.

To Reproduce (REQUIRED)

Input:

# download yolov5m.pt model first
$ python export.py --weight yolov5m.pt --include onnx --device 0

Output:

Traceback (most recent call last):
  File "export.py", line 54, in export_onnx
    torch.onnx.export(model, img, f, verbose=False, opset_version=11,
  File "/home/sam/human_body_proportion_estimation/yolov5/venv/lib/python3.8/site-packages/torch/onnx/__init__.py", line 275, in export
    return utils.export(model, args, f, export_params, verbose, training,
  File "/home/sam/human_body_proportion_estimation/yolov5/venv/lib/python3.8/site-packages/torch/onnx/utils.py", line 88, in export
    _export(model, args, f, export_params, verbose, training, input_names, output_names,
  File "/home/sam/human_body_proportion_estimation/yolov5/venv/lib/python3.8/site-packages/torch/onnx/utils.py", line 689, in _export
    _model_to_graph(model, args, verbose, input_names,
  File "/home/sam/human_body_proportion_estimation/yolov5/venv/lib/python3.8/site-packages/torch/onnx/utils.py", line 501, in _model_to_graph
    params_dict = torch._C._jit_pass_onnx_constant_fold(graph, params_dict,
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking arugment for argument index in method wrapper_index_select)

Expected behavior

Successful ONNX export.

Environment

  • OS: Ubuntu
  • GPU: Tesla V100

Additional context

I also checked whether the model and the img tensor passed to the ONNX export are on the same CUDA device, and they were.
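A check like the one described above can be sketched as follows, assuming a plain torch.nn.Module; collect_devices is a hypothetical helper name. It inspects parameters, registered buffers and the inputs only, so a single reported device does not rule out a tensor created inside forward() (or during tracing) ending up on the CPU.

```python
import torch
import torch.nn as nn


def collect_devices(model: nn.Module, *inputs: torch.Tensor) -> set:
    """Return every device used by the model's parameters, its buffers and the inputs."""
    devices = {p.device for p in model.parameters()}
    devices |= {b.device for b in model.buffers()}  # e.g. BatchNorm running statistics
    devices |= {t.device for t in inputs}
    return devices


if __name__ == "__main__":
    model = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.BatchNorm2d(8))
    img = torch.zeros(1, 3, 64, 64)
    devices = collect_devices(model, img)
    assert len(devices) == 1, f"tensors on multiple devices: {devices}"
    print(devices)
```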

SamSamhuns added the bug (Something isn't working) label on Jul 26, 2021
@glenn-jocher
Member

glenn-jocher commented Jul 26, 2021

@SamSamhuns yes this has been reported before. ONNX models must be exported on CPU device for now. If you determine the cause of the issue please submit a PR to help other users, thank you!
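Following that advice, a pre-export step can be sketched as below. The function names are hypothetical, the model and example input are assumed to be an ordinary torch.nn.Module and tensor, and the default opset of 11 only mirrors the opset_version=11 seen in the first traceback.

```python
import torch
import torch.nn as nn


def prepare_for_onnx_export(model: nn.Module, img: torch.Tensor):
    """Move the model and the example input to CPU, in FP32 and eval mode."""
    # nn.Module.float/.cpu/.eval act in place and return the module itself.
    return model.float().cpu().eval(), img.float().cpu()


def export_on_cpu(model: nn.Module, img: torch.Tensor, f: str, opset: int = 11) -> None:
    """Export after forcing everything onto the CPU (the workaround suggested above)."""
    model, img = prepare_for_onnx_export(model, img)
    torch.onnx.export(model, img, f, opset_version=opset,
                      input_names=["images"], output_names=["output"])
```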

@SamSamhuns
Contributor Author

@glenn-jocher I found a way to avoid the error when exporting on GPU, but I'm not sure whether it's worth a PR.

It seems the error occurs when do_constant_folding is enabled in the ONNX export call: the failure is raised at line 501 of torch/onnx/utils.py, in params_dict = torch._C._jit_pass_onnx_constant_fold(graph, params_dict, _export_onnx_opset_version).

Unfortunately, even after verifying that every model parameter was on the correct CUDA device, the error persists.

However, GPU export works with the following change, although disabling constant folding might incur a computational penalty:

torch.onnx.export(model, img, f, verbose=False, opset_version=opset,
                  training=torch.onnx.TrainingMode.TRAINING if train else torch.onnx.TrainingMode.EVAL,
                  do_constant_folding=(not train) and (not next(model.parameters()).is_cuda),  # Additional check if cuda used
                  input_names=['images'],
                  output_names=['output'],
                  dynamic_axes={'images': {0: 'batch', 2: 'height', 3: 'width'},  # shape(1,3,640,640)
                                'output': {0: 'batch', 1: 'anchors'}  # shape(1,25200,85)
                                } if dynamic else None)
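The same condition can be factored into a small helper that builds the keyword arguments for torch.onnx.export, which makes the folding decision easy to check without running an export. This is a sketch that mirrors the snippet above; the function name is hypothetical, and opset, train and dynamic stand in for the corresponding export.py options.

```python
import torch
import torch.nn as nn
import torch.onnx  # make sure the submodule (and TrainingMode) is loaded


def onnx_export_kwargs(model: nn.Module, opset: int, train: bool = False, dynamic: bool = False) -> dict:
    """Build keyword arguments for torch.onnx.export, mirroring the workaround above.

    Constant folding is skipped when the model's parameters live on a CUDA device,
    which is the case in which the folding pass raised the device error in this thread.
    """
    on_cuda = next(model.parameters()).is_cuda
    return dict(
        verbose=False,
        opset_version=opset,
        training=torch.onnx.TrainingMode.TRAINING if train else torch.onnx.TrainingMode.EVAL,
        do_constant_folding=(not train) and (not on_cuda),
        input_names=["images"],
        output_names=["output"],
        dynamic_axes={"images": {0: "batch", 2: "height", 3: "width"},  # shape(1,3,640,640)
                      "output": {0: "batch", 1: "anchors"}} if dynamic else None,  # shape(1,25200,85)
    )
```

Usage would then be torch.onnx.export(model, img, f, **onnx_export_kwargs(model, opset)).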

Note: setting model.model[-1].export = True, as suggested in some other issues, did not solve the problem either.

@glenn-jocher
Member

@SamSamhuns can you quantify the penalty, i.e. extra layers or difference in parameters between the two export methods, or profiling results when running python detect.py --weights yolov5s.onnx?
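One way to look for extra layers or parameter differences is to compare operator counts and initializer sizes of the two exported files. The sketch below assumes the onnx Python package is installed for loading; the helper names are hypothetical.

```python
import math
from collections import Counter


def diff_op_counts(ops_a, ops_b):
    """Return {op_type: (count_a, count_b)} for every operator whose count differs."""
    a, b = Counter(ops_a), Counter(ops_b)
    return {op: (a[op], b[op]) for op in sorted(set(a) | set(b)) if a[op] != b[op]}


def summarize_onnx(path):
    """Load an ONNX file and return (operator types, number of stored initializer values)."""
    import onnx  # imported lazily so diff_op_counts works without the package

    graph = onnx.load(path).graph
    op_types = [node.op_type for node in graph.node]
    n_params = sum(math.prod(init.dims) for init in graph.initializer)
    return op_types, n_params
```

For example, calling summarize_onnx on a CPU-exported and a GPU-exported file (hypothetical names such as yolov5s_cpu.onnx and yolov5s_gpu.onnx) and passing both operator lists to diff_op_counts lists the operators whose counts differ, while the second tuple element gives the total number of stored weight values for each export.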

@SamSamhuns
Contributor Author

At a cursory glance, the ONNX models exported on CPU and on GPU perform the same in this case, in terms of both accuracy and speed.

So the additional (not next(model.parameters()).is_cuda) condition can be added as a temporary check to avoid the error, but it should not be a long-term solution.

(Screenshots: detection results on the bus image for the CPU-exported and the GPU-exported ONNX models.)

However, when I use the --half option to export ONNX on GPU, the export completes, but the FP16 model unfortunately fails at inference:

Traceback (most recent call last):
  File "detect.py", line 243, in <module>
    main(opt)
  File "detect.py", line 238, in main
    run(**vars(opt))
  File "/home/sam/yolov5/venv/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 28, in decorate_context
    return func(*args, **kwargs)
  File "detect.py", line 82, in run
    session = onnxruntime.InferenceSession(w, None)
  File "/home/sam/yolov5/venv/lib/python3.8/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 283, in __init__
    self._create_inference_session(providers, provider_options, disabled_optimizers)
  File "/home/sam/yolov5/venv/lib/python3.8/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 310, in _create_inference_session
    sess = C.InferenceSession(session_options, self._model_path, True, self._read_config_from_model)
onnxruntime.capi.onnxruntime_pybind11_state.Fail: [ONNXRuntimeError] : 1 : FAIL : Load model from yolov5s_gpu_half.onnx failed:Type Error: Type parameter (T) of Optype (Concat) bound to different types (tensor(float) and tensor(float16) in node (Concat_540).

For some reason, a node (Concat_540) still receives a float32 input even though the model was exported in float16, so there are still some issues with GPU ONNX export.
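The error means the Concat node received one float32 and one float16 input. One plausible cause, which is an assumption and not something confirmed in this thread, is a helper tensor (such as a coordinate grid) created with PyTorch's default float32 dtype and then combined with FP16 activations. The sketch below shows a grid helper that follows the dtype and device of a reference tensor; make_grid is hypothetical and is not the YOLOv5 implementation.

```python
import torch


def make_grid(nx: int, ny: int, like: torch.Tensor) -> torch.Tensor:
    """Build an (ny, nx, 2) grid of (x, y) cell coordinates matching `like`'s dtype and device."""
    # Create the ranges in the default integer dtype, then cast once with .to(like),
    # so the grid never silently stays float32 when `like` is float16.
    xs = torch.arange(nx).to(like)
    ys = torch.arange(ny).to(like)
    xv = xs.view(1, nx).expand(ny, nx)
    yv = ys.view(ny, 1).expand(ny, nx)
    return torch.stack((xv, yv), dim=2)


if __name__ == "__main__":
    feats = torch.zeros(3, 2, 4, dtype=torch.float16)  # stand-in for FP16 activations
    grid = make_grid(2, 3, feats)
    out = torch.cat((grid, feats), dim=2)  # both inputs are float16, so no type mismatch
    print(out.dtype, tuple(out.shape))
```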

@glenn-jocher
Member

@SamSamhuns there's really no reason to export on GPU other than to produce an FP16 model. FP16 models don't run in PyTorch on CPU because the PyTorch CPU backend can't handle them; I don't know about ONNX.
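In practice this means FP16 export should be paired with a CUDA device. A minimal sketch of such a guard, assuming device is a torch.device; the function name and the fall-back-to-FP32 policy are hypothetical choices (an implementation could just as well raise an error instead).

```python
import torch


def export_dtype(device: torch.device, half: bool) -> torch.dtype:
    """Choose the export precision: FP16 only when half is requested on a CUDA device."""
    if half and device.type != "cuda":
        # Hypothetical policy: warn and fall back to FP32 rather than export FP16 on CPU.
        print(f"--half requested on {device}, exporting in FP32 instead")
        return torch.float32
    return torch.float16 if half else torch.float32
```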

@SamSamhuns
Contributor Author

Makes sense. In any case, there is some underlying issue in ONNX or in PyTorch's ONNX export that is causing this.

@github-actions
Contributor

github-actions bot commented Aug 27, 2021

👋 Hello, this issue has been automatically marked as stale because it has not had recent activity. Please note it will be closed if no further activity occurs.

Feel free to inform us of any other issues you discover or feature requests that come to mind in the future. Pull Requests (PRs) are also always welcomed!

Thank you for your contributions to YOLOv5 🚀 and Vision AI ⭐!

@LaserLV52

Hello, I followed your solution and changed do_constant_folding=not train to do_constant_folding=(not train) and (not next(model.parameters()).is_cuda). Exporting the ONNX model with --device 0 then no longer raised the error. However, when I used the ONNX file in detect.py, I found that the model still runs on CPU, not GPU. Do you have any idea why? Thanks!
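One likely explanation, offered as an assumption rather than something confirmed in this thread, is on the inference side: running an ONNX model on GPU requires the onnxruntime-gpu package (the plain onnxruntime package is CPU-only), and newer ONNX Runtime releases require the providers argument to be passed explicitly on GPU-enabled builds, whereas detect.py at the time called onnxruntime.InferenceSession(w, None). A sketch that selects providers explicitly; pick_providers and make_session are hypothetical names, while get_available_providers, the providers argument and get_providers are part of the onnxruntime Python API.

```python
def pick_providers(available, prefer_cuda=True):
    """Order ONNX Runtime execution providers so CUDA is tried first when it is installed."""
    wanted = ["CUDAExecutionProvider", "CPUExecutionProvider"] if prefer_cuda else ["CPUExecutionProvider"]
    chosen = [p for p in wanted if p in available]
    return chosen or list(available)


def make_session(path, prefer_cuda=True):
    """Create an InferenceSession that explicitly requests the chosen providers."""
    import onnxruntime as ort  # GPU execution needs the onnxruntime-gpu package

    providers = pick_providers(ort.get_available_providers(), prefer_cuda)
    session = ort.InferenceSession(path, providers=providers)
    print("active providers:", session.get_providers())
    return session
```

If the printed list contains only CPUExecutionProvider, the installed ONNX Runtime build has no CUDA support.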

@glenn-jocher
Member

@SamSamhuns @LaserLV52 good news 😃! Your original issue may now be fixed ✅ in PR #5110 by @SamFC10. This PR implements backend-device change improvements that allow YOLOv5 models to be exported to ONNX on either GPU or CPU, and to be exported at FP16 with the --half flag on GPU (--device 0).

To receive this update:

  • Git – git pull from within your yolov5/ directory or git clone https://github.com/ultralytics/yolov5 again
  • PyTorch Hub – Force-reload with model = torch.hub.load('ultralytics/yolov5', 'yolov5s', force_reload=True)
  • Notebooks – View updated notebooks (Colab, Kaggle)
  • Docker – sudo docker pull ultralytics/yolov5:latest to update your image

Thank you for spotting this issue and informing us of the problem. Please let us know if this update resolves the issue for you, and feel free to inform us of any other issues you discover or feature requests that come to mind. Happy trainings with YOLOv5 🚀!
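With an FP16 export, the array fed at inference time also has to match the model's declared input type, since ONNX Runtime rejects a float32 array for a float16 input. A sketch, assuming an ONNX Runtime session whose session.get_inputs()[0].type reports a string such as "tensor(float16)"; the mapping table and prepare_input are hypothetical helpers covering only the two float types discussed here.

```python
import numpy as np

# Hypothetical mapping from ONNX Runtime type strings to NumPy dtypes (float types only).
ORT_FLOAT_TYPES = {"tensor(float16)": np.float16, "tensor(float)": np.float32}


def prepare_input(img: np.ndarray, ort_type: str) -> np.ndarray:
    """Cast a preprocessed image to the dtype the ONNX model input declares."""
    try:
        dtype = ORT_FLOAT_TYPES[ort_type]
    except KeyError:
        raise ValueError(f"unsupported input type: {ort_type}") from None
    return np.ascontiguousarray(img, dtype=dtype)
```

Typical use: inp = session.get_inputs()[0], then session.run(None, {inp.name: prepare_input(img, inp.type)}).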
