Skip to content

Commit

Permalink
Support ver 7 (#193)
Browse files Browse the repository at this point in the history
* bug fix and some improvements

* support ver 7 and tile

* yapf

* support bias_add and improve coverage getter
  • Loading branch information
fumihwh committed May 30, 2018
1 parent cd8a48b commit 516e986
Show file tree
Hide file tree
Showing 22 changed files with 150 additions and 61 deletions.
30 changes: 15 additions & 15 deletions doc/support_status.md
Original file line number Diff line number Diff line change
Expand Up @@ -130,52 +130,52 @@ ______
| -------------- |:------------------:|
|abs|1, 6|
|acos|7|
|add|1, 6|
|add|1, 6, 7|
|add_n|1, 6|
|arg_max|1|
|arg_min|1|
|asin|7|
|atan|7|
|avg_pool|1, 7|
|batch_norm|1, 6|
|bias_add|1, 6|
|batch_norm|1, 6, 7|
|bias_add|1, 6, 7|
|cast|1, 6|
|ceil|1, 6|
|concat_v2|1, 4|
|conv1_d|1|
|conv2_d|1|
|conv3_d|1|
|cos|7|
|equal|1|
|equal|1, 7|
|exp|1, 6|
|expand_dims|1|
|fill|1|
|floor|1, 6|
|fused_batch_norm|1, 6|
|greater|1|
|fused_batch_norm|1, 6, 7|
|greater|1, 7|
|identity|1|
|less|1|
|less|1, 7|
|log|1, 6|
|log_softmax|1|
|logical_and|1|
|logical_and|1, 7|
|logical_not|1|
|logical_or|1|
|logical_xor|1|
|logical_or|1, 7|
|logical_xor|1, 7|
|mat_mul|1|
|max|1|
|max_pool|1|
|maximum|1, 6|
|mean|1|
|min|1|
|minimum|1, 6|
|mul|1, 6|
|mul|1, 6, 7|
|neg|1, 6|
|pad|1, 2|
|pow|1|
|pow|1, 7|
|prod|1|
|random_standard_normal|1|
|random_uniform|1|
|real_div|1, 6|
|real_div|1, 6, 7|
|reciprocal|1, 6|
|relu|1, 6|
|reshape|1, 5|
Expand All @@ -190,10 +190,10 @@ ______
|split_v|1, 2|
|sqrt|1, 6|
|squeeze|1|
|sub|1, 6|
|sub|1, 6, 7|
|sum|1|
|tan|7|
|tanh|1, 6|
|tile|1|
|tile|6|
|top_k_v2|1|
|transpose|1|
11 changes: 9 additions & 2 deletions onnx_tf/common/handler_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,13 @@ def get_frontend_coverage():
:return: onnx_coverage: e.g. {'domain': {'ONNX_OP': [versions], ...}, ...}
tf_coverage: e.g. {'domain': {'TF_OP': [versions], ...}, ...}
"""

def _update_coverage(coverage, domain, key, versions):
domain_coverage = coverage.setdefault(domain, {})
vers = domain_coverage.get(key, [])
vers.extend(versions)
domain_coverage[key] = sorted(list(set(vers)))

tf_coverage = {}
onnx_coverage = {}
for handler in FrontendHandler.__subclasses__():
Expand All @@ -54,6 +61,6 @@ def get_frontend_coverage():
versions = handler.get_versions()
domain = handler.DOMAIN
for tf_op in handler.TF_OP:
tf_coverage.setdefault(domain, {})[op_name_to_lower(tf_op)] = versions
onnx_coverage.setdefault(domain, {})[handler.ONNX_OP] = versions
_update_coverage(tf_coverage, domain, op_name_to_lower(tf_op), versions)
_update_coverage(onnx_coverage, domain, handler.ONNX_OP, versions)
return onnx_coverage, tf_coverage
4 changes: 4 additions & 0 deletions onnx_tf/handlers/frontend/add.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,7 @@ def version_1(cls, node, **kwargs):
@classmethod
def version_6(cls, node, **kwargs):
return cls.arithmetic_op(node, **kwargs)

@classmethod
def version_7(cls, node, **kwargs):
return cls.arithmetic_op(node, **kwargs)
3 changes: 1 addition & 2 deletions onnx_tf/handlers/frontend/argmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,4 @@ def args_check(cls, node, **kwargs):
@classmethod
def version_1(cls, node, **kwargs):
axis = np.asscalar(kwargs["consts"][node.inputs[1]])
return cls.make_node(
node, [node.inputs[0]], axis=axis, keepdims=0)
return cls.make_node(node, [node.inputs[0]], axis=axis, keepdims=0)
3 changes: 1 addition & 2 deletions onnx_tf/handlers/frontend/argmin.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,4 @@ def args_check(cls, node, **kwargs):
@classmethod
def version_1(cls, node, **kwargs):
axis = np.asscalar(kwargs["consts"][node.inputs[1]])
return cls.make_node(
node, [node.inputs[0]], axis=axis, keepdims=0)
return cls.make_node(node, [node.inputs[0]], axis=axis, keepdims=0)
6 changes: 6 additions & 0 deletions onnx_tf/handlers/frontend/batch_normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,9 @@ def version_6(cls, node, **kwargs):
node,
epsilon=node.attr.get("epsilon", 1e-5),
is_test=node.attr.get("is_training", 0))

@classmethod
def version_7(cls, node, **kwargs):
return cls.make_node(
node,
epsilon=node.attr.get("epsilon", 1e-5))
4 changes: 4 additions & 0 deletions onnx_tf/handlers/frontend/bias_add.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,7 @@ def version_6(cls, node, **kwargs):
channel_first = data_format[1] == "C"
axis = 1 if channel_first else -1
return cls.arithmetic_op(node, axis=axis, **kwargs)

@classmethod
def version_7(cls, node, **kwargs):
return cls.arithmetic_op(node, **kwargs)
38 changes: 29 additions & 9 deletions onnx_tf/handlers/frontend/control_flow_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,14 @@
class LogicalMixin(object):

@classmethod
def logical_op(cls, node, broadcast=1, **kwargs):
def logical_op(cls, node, **kwargs):
if cls.SINCE_VERSION <= 6:
return cls._limited_broadcast(node, **kwargs)
else: # since_version >= 7
return cls._np_broadcast(node, **kwargs)

@classmethod
def _limited_broadcast(cls, node, broadcast=1, **kwargs):
ex_kwargs = {}
if broadcast == 1:
ex_kwargs["broadcast"] = 1
Expand All @@ -14,16 +21,29 @@ def logical_op(cls, node, broadcast=1, **kwargs):
ex_kwargs["axis"] = axis
return cls.make_node(node, **ex_kwargs)

@classmethod
def _np_broadcast(cls, node, **kwargs):
return cls.make_node(node)


class ComparisonMixin(object):

@classmethod
def comparison_op(cls, node, broadcast=1, **kwargs):
ex_kwargs = {}
if broadcast == 1:
ex_kwargs["broadcast"] = 1
node_dict = kwargs["node_dict"]
axis = get_broadcast_axis(*[node_dict[x] for x in node.inputs[0:2]])
if axis is not None:
ex_kwargs["axis"] = axis
def comparison_op(cls, node, **kwargs):
if cls.SINCE_VERSION <= 6:
return cls._limited_broadcast(node, **kwargs)
else: # since_version >= 7
return cls._np_broadcast(node, **kwargs)

@classmethod
def _limited_broadcast(cls, node, **kwargs):
ex_kwargs = {"broadcast": 1}
node_dict = kwargs["node_dict"]
axis = get_broadcast_axis(*[node_dict[x] for x in node.inputs[0:2]])
if axis is not None:
ex_kwargs["axis"] = axis
return cls.make_node(node, **ex_kwargs)

@classmethod
def _np_broadcast(cls, node, **kwargs):
return cls.make_node(node)
4 changes: 4 additions & 0 deletions onnx_tf/handlers/frontend/div.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,7 @@ def version_1(cls, node, **kwargs):
@classmethod
def version_6(cls, node, **kwargs):
return cls.arithmetic_op(node, **kwargs)

@classmethod
def version_7(cls, node, **kwargs):
return cls.arithmetic_op(node, **kwargs)
4 changes: 4 additions & 0 deletions onnx_tf/handlers/frontend/equal.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,7 @@ class Equal(ComparisonMixin, FrontendHandler):
@classmethod
def version_1(cls, node, **kwargs):
return cls.comparison_op(node, **kwargs)

@classmethod
def version_7(cls, node, **kwargs):
return cls.comparison_op(node, **kwargs)
3 changes: 1 addition & 2 deletions onnx_tf/handlers/frontend/fill.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,4 @@ def args_check(cls, node, **kwargs):
@classmethod
def version_1(cls, node, **kwargs):
value = float(np.asscalar(kwargs["consts"][node.inputs[1]]))
return cls.make_node(
node, [node.inputs[0]], input_as_shape=1, value=value)
return cls.make_node(node, [node.inputs[0]], input_as_shape=1, value=value)
4 changes: 4 additions & 0 deletions onnx_tf/handlers/frontend/greater.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,7 @@ class Greater(ComparisonMixin, FrontendHandler):
@classmethod
def version_1(cls, node, **kwargs):
return cls.comparison_op(node, **kwargs)

@classmethod
def version_7(cls, node, **kwargs):
return cls.comparison_op(node, **kwargs)
4 changes: 4 additions & 0 deletions onnx_tf/handlers/frontend/less.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,7 @@ class Less(ComparisonMixin, FrontendHandler):
@classmethod
def version_1(cls, node, **kwargs):
return cls.comparison_op(node, **kwargs)

@classmethod
def version_7(cls, node, **kwargs):
return cls.comparison_op(node, **kwargs)
4 changes: 4 additions & 0 deletions onnx_tf/handlers/frontend/logical_and.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,7 @@ class LogicalAnd(LogicalMixin, FrontendHandler):
@classmethod
def version_1(cls, node, **kwargs):
return cls.logical_op(node, **kwargs)

@classmethod
def version_7(cls, node, **kwargs):
return cls.logical_op(node, **kwargs)
4 changes: 4 additions & 0 deletions onnx_tf/handlers/frontend/logical_or.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,7 @@ class LogicalOr(LogicalMixin, FrontendHandler):
@classmethod
def version_1(cls, node, **kwargs):
return cls.logical_op(node, **kwargs)

@classmethod
def version_7(cls, node, **kwargs):
return cls.logical_op(node, **kwargs)
4 changes: 4 additions & 0 deletions onnx_tf/handlers/frontend/logical_xor.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,7 @@ class LogicalXor(LogicalMixin, FrontendHandler):
@classmethod
def version_1(cls, node, **kwargs):
return cls.logical_op(node, **kwargs)

@classmethod
def version_7(cls, node, **kwargs):
return cls.logical_op(node, **kwargs)
11 changes: 11 additions & 0 deletions onnx_tf/handlers/frontend/math_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,13 @@ class ArithmeticMixin(object):

@classmethod
def arithmetic_op(cls, node, **kwargs):
if cls.SINCE_VERSION <= 6:
return cls._limited_broadcast(node, **kwargs)
else: # since_version >= 7
return cls._np_broadcast(node, **kwargs)

@classmethod
def _limited_broadcast(cls, node, **kwargs):
node_dict = kwargs["node_dict"]
axis = kwargs.get(
"axis", get_broadcast_axis(*[node_dict[x] for x in node.inputs[0:2]]))
Expand All @@ -21,6 +28,10 @@ def arithmetic_op(cls, node, **kwargs):
ex_kwargs["axis"] = axis
return cls.make_node(node, broadcast=1, **ex_kwargs)

@classmethod
def _np_broadcast(cls, node, **kwargs):
return cls.make_node(node)


class ReductionMixin(object):

Expand Down
4 changes: 4 additions & 0 deletions onnx_tf/handlers/frontend/multiply.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,7 @@ def version_1(cls, node, **kwargs):
@classmethod
def version_6(cls, node, **kwargs):
return cls.arithmetic_op(node, **kwargs)

@classmethod
def version_7(cls, node, **kwargs):
return cls.arithmetic_op(node, **kwargs)
4 changes: 4 additions & 0 deletions onnx_tf/handlers/frontend/pow.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,7 @@ class Pow(BasicMathMixin, FrontendHandler):
@classmethod
def version_1(cls, node, **kwargs):
return cls.basic_math_op(node, **kwargs)

@classmethod
def version_7(cls, node, **kwargs):
return cls.basic_math_op(node, **kwargs)
4 changes: 4 additions & 0 deletions onnx_tf/handlers/frontend/subtract.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,7 @@ def version_1(cls, node, **kwargs):
@classmethod
def version_6(cls, node, **kwargs):
return cls.arithmetic_op(node, **kwargs)

@classmethod
def version_7(cls, node, **kwargs):
return cls.arithmetic_op(node, **kwargs)
2 changes: 1 addition & 1 deletion onnx_tf/handlers/frontend/tile.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
class Tile(FrontendHandler):

@classmethod
def version_1(cls, node, **kwargs):
def version_6(cls, node, **kwargs):
data_type_cast_map = kwargs["data_type_cast_map"]
data_type_cast_map[node.inputs[1]] = TensorProto.INT64
return cls.make_node(node)
Loading

0 comments on commit 516e986

Please sign in to comment.