Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Rebasing
Browse files Browse the repository at this point in the history
  • Loading branch information
vandanavk committed Aug 9, 2019
1 parent 43ce402 commit 387e909
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 2 deletions.
6 changes: 5 additions & 1 deletion python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
Original file line number Diff line number Diff line change
Expand Up @@ -2059,10 +2059,14 @@ def convert_topk(node, **kwargs):
axis = int(attrs.get('axis', '-1'))
k = int(attrs.get('k', '1'))
ret_type = attrs.get('ret_typ')
dtype = attrs.get('dtype')
outputs = [name + '_output0']

if ret_type and ret_type == 'both':
outputs.append(name + '_output1')
if dtype and dtype == 'int64':
outputs.append(name + '_output1')
else:
raise NotImplementedError("ONNX expects indices to be of type int64")
else:
raise NotImplementedError("ONNX expects both value and indices as output")

Expand Down
3 changes: 2 additions & 1 deletion python/mxnet/contrib/onnx/onnx2mx/_op_translations.py
Original file line number Diff line number Diff line change
Expand Up @@ -784,5 +784,6 @@ def lpnormalization(attrs, inputs, proto_obj):
def topk(attrs, inputs, proto_obj):
"""Returns the top k elements in an input array along the given axis."""
new_attrs = translation_utils._add_extra_attributes(attrs,
{'ret_typ': 'both'})
{'ret_typ': 'both',
'dtype': 'int64'})
return 'topk', new_attrs, inputs

0 comments on commit 387e909

Please sign in to comment.