Skip to content

Commit

Permalink
Update data_type.py (#238)
Browse files Browse the repository at this point in the history
  • Loading branch information
tjingrant committed Aug 17, 2018
1 parent 4d3836f commit 53e3bbb
Showing 1 changed file with 10 additions and 1 deletion.
11 changes: 10 additions & 1 deletion onnx_tf/common/data_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,16 @@ def tf2onnx(dtype):
if isinstance(dtype, Number):
tf_dype = tf.as_dtype(dtype)
elif isinstance(dtype, tf.DType):
tf_dype = dtype
# Usually, tf2onnx is done via tf_type->numpy_type->onnx_type
# to leverage existing type conversion infrastructure;
# However, we need to intercept the string type early because
# lowering tf.string type to numpy dtype results in loss of
# information. <class 'object'> is returned instead of the
# numpy string type desired.
if tf_dype is tf.string:
return TensorProto.STRING
else:
tf_dype = dtype
else:
raise RuntimeError("dtype should be number or tf.DType.")
return mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(tf_dype.as_numpy_dtype)]
Expand Down

0 comments on commit 53e3bbb

Please sign in to comment.