Fix a bug that causes onnxruntime.capi.onnxruntime_pybind11_state.Fail #1028
Fixes a bug that can cause the following error when loading the exported model:
onnxruntime.capi.onnxruntime_pybind11_state.Fail: [ONNXRuntimeError] : 1 : FAIL : Load model from *.onnx failed:Type Error: Type parameter (T) of Optype (Where) bound to different types (tensor(int64) and tensor(float) in node (/end2end/Where_1).
The cause: when exporting to ONNX, `batched_labels.new_ones` should create a tensor of the same type as `batched_labels` itself (which is what happens when the code runs in Python). In the exported model, however, the created tensor has type float, which does not match `batched_labels` (int), and this produces the error at runtime.
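The exact change in this PR is visible only in the screenshots, so the following is just a minimal sketch of the general idea, not the PR's actual code. It assumes that stating the dtype explicitly leaves the exporter less room to fall back to float; that assumption has not been verified against the affected PyTorch build. The helper name `ones_with_label_dtype` and the sample tensors are hypothetical.

```python
import torch


def ones_with_label_dtype(batched_labels: torch.Tensor, shape) -> torch.Tensor:
    """Hypothetical helper: build a tensor of ones whose dtype and device
    are taken explicitly from batched_labels instead of being inferred."""
    return torch.ones(shape, dtype=batched_labels.dtype, device=batched_labels.device)


# Sample integer labels (illustrative data only).
batched_labels = torch.tensor([[1, 2, 3]], dtype=torch.int64)

# In eager mode new_ones already inherits the dtype of batched_labels.
implicit = batched_labels.new_ones((1, 3))

# Same values, but with the dtype spelled out.
explicit = ones_with_label_dtype(batched_labels, (1, 3))

# Both value branches of where now share one dtype, which is what the
# ONNX Where operator's type parameter T requires.
mask = batched_labels > 1
result = torch.where(mask, batched_labels, explicit)
```

In eager mode both `implicit` and `explicit` are int64, so the difference only matters for the exported graph.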
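This problem did not occur in older versions. It is not yet clear whether that is because older PyTorch versions did not have this export bug, or because the ONNX Where operator could previously be used across types.
The versions in which this bug appears are:
pytorch 2.1.0a0+41361538.nv23.6 + onnxruntime-gpu 1.17.0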
Before the change:
![图片](https://private-user-images.githubusercontent.com/73748897/312046526-8e48a52c-308c-4455-99d0-c77352b8b67e.png?jwt=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJnaXRodWIuY29tIiwiYXVkIjoicmF3LmdpdGh1YnVzZXJjb250ZW50LmNvbSIsImtleSI6ImtleTUiLCJleHAiOjE3MjMxMTI1NDYsIm5iZiI6MTcyMzExMjI0NiwicGF0aCI6Ii83Mzc0ODg5Ny8zMTIwNDY1MjYtOGU0OGE1MmMtMzA4Yy00NDU1LTk5ZDAtYzc3MzUyYjhiNjdlLnBuZz9YLUFtei1BbGdvcml0aG09QVdTNC1ITUFDLVNIQTI1NiZYLUFtei1DcmVkZW50aWFsPUFLSUFWQ09EWUxTQTUzUFFLNFpBJTJGMjAyNDA4MDglMkZ1cy1lYXN0LTElMkZzMyUyRmF3czRfcmVxdWVzdCZYLUFtei1EYXRlPTIwMjQwODA4VDEwMTcyNlomWC1BbXotRXhwaXJlcz0zMDAmWC1BbXotU2lnbmF0dXJlPTg2MWIxN2U1Mjc4ZWQwNmE3MDI0NjZjNzc3YWYwZjBiNjRkNTg4NWNiMWNiNzI1MzZjMTI1MGJhNTU1MTliNTkmWC1BbXotU2lnbmVkSGVhZGVycz1ob3N0JmFjdG9yX2lkPTAma2V5X2lkPTAmcmVwb19pZD0wIn0.8bhh4ZTH3JaJULc27D9VoV2vLzDMaCeVJJZeZicE-Yo)
After the change:
![图片](https://private-user-images.githubusercontent.com/73748897/312046644-383fa94c-263f-48e9-890f-ed476c32dd9a.png?jwt=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJnaXRodWIuY29tIiwiYXVkIjoicmF3LmdpdGh1YnVzZXJjb250ZW50LmNvbSIsImtleSI6ImtleTUiLCJleHAiOjE3MjMxMTI1NDYsIm5iZiI6MTcyMzExMjI0NiwicGF0aCI6Ii83Mzc0ODg5Ny8zMTIwNDY2NDQtMzgzZmE5NGMtMjYzZi00OGU5LTg5MGYtZWQ0NzZjMzJkZDlhLnBuZz9YLUFtei1BbGdvcml0aG09QVdTNC1ITUFDLVNIQTI1NiZYLUFtei1DcmVkZW50aWFsPUFLSUFWQ09EWUxTQTUzUFFLNFpBJTJGMjAyNDA4MDglMkZ1cy1lYXN0LTElMkZzMyUyRmF3czRfcmVxdWVzdCZYLUFtei1EYXRlPTIwMjQwODA4VDEwMTcyNlomWC1BbXotRXhwaXJlcz0zMDAmWC1BbXotU2lnbmF0dXJlPWYwNzdkYWJmZGI3MTYwMGNkMmYyNWZhZjQzMDgzNTc2YjIwYTEwOTZjMjQ2NGI0NTg3ZWM5MzNlNzQwNjFmYmEmWC1BbXotU2lnbmVkSGVhZGVycz1ob3N0JmFjdG9yX2lkPTAma2V5X2lkPTAmcmVwb19pZD0wIn0.SquD6okKsCU7aA-q0XJLiKMKpDzXOAllhWDXfvefNfE)
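Note: the two issues below both report a similar error. One is caused by seg not supporting end2end export; the other is caused by this bug. Although the issues have been closed, the root cause was never resolved.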
#1021
#1013 (comment)