Skip to content

Commit

Permalink
Add SpaceToDepth in frontend (#183)
Browse files Browse the repository at this point in the history
* Add SpaceToDepth

* NHWC space to depth

* Refactoring. Move get_perm_from_formats() to common.py
  • Loading branch information
ruimashita authored and tjingrant committed May 21, 2018
1 parent 02f5001 commit cd3768e
Show file tree
Hide file tree
Showing 7 changed files with 56 additions and 17 deletions.
4 changes: 0 additions & 4 deletions onnx_tf/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,10 +473,6 @@ def get_data_format(cls, x_rank, support_cuda):
compute_format = "N" + sp_dim_string + "C"
return storage_format, compute_format

@classmethod
def get_perm_from_formats(cls, _from, _to):
return list(map(lambda x: _from.find(x), _to))

@classmethod
def supports_device(cls, device):
if device == "CUDA":
Expand Down
25 changes: 13 additions & 12 deletions onnx_tf/backends/backend_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
ONNX_OP_TO_TF_OP,
ONNX_TYPE_TO_TF_TYPE,
PAD_TF_INCOMPATIBLE,
get_perm_from_formats,
)
import onnx.numpy_helper
import onnx.defs
Expand Down Expand Up @@ -227,7 +228,7 @@ def _pool(cls, node, input_dict, pool_func, pooling_type):
data_format=compute_format)
else:
x = tf.transpose(
x, perm=cls.get_perm_from_formats(storage_format, compute_format))
x, perm=get_perm_from_formats(storage_format, compute_format))
pooled = pool_func(
x,
kernel_shape,
Expand All @@ -236,7 +237,7 @@ def _pool(cls, node, input_dict, pool_func, pooling_type):
data_format=compute_format)
pooled = tf.transpose(
pooled,
perm=cls.get_perm_from_formats(compute_format, storage_format))
perm=get_perm_from_formats(compute_format, storage_format))

return [pooled]

Expand Down Expand Up @@ -393,7 +394,7 @@ def _conv(cls, node, input_dict, transpose=False):
xs = tf.split(x, num_or_size_splits=group, axis=1)
else:
x = tf.transpose(
x, perm=cls.get_perm_from_formats(storage_format, compute_format))
x, perm=get_perm_from_formats(storage_format, compute_format))
xs = tf.split(x, num_or_size_splits=group, axis=-1)

if transpose:
Expand Down Expand Up @@ -485,7 +486,7 @@ def _conv(cls, node, input_dict, transpose=False):
output = tf.concat(convolved, axis=-1)
output = tf.transpose(
output,
perm=cls.get_perm_from_formats(compute_format, storage_format))
perm=get_perm_from_formats(compute_format, storage_format))
else:
bias = input_dict[node.inputs[2]]
bias = cls._explicit_broadcast(
Expand All @@ -499,7 +500,7 @@ def _conv(cls, node, input_dict, transpose=False):
output = tf.add(output, bias)
output = tf.transpose(
output,
perm=cls.get_perm_from_formats(compute_format, storage_format))
perm=get_perm_from_formats(compute_format, storage_format))

return [output]

Expand All @@ -522,11 +523,11 @@ def handle_depth_to_space(cls, node, input_dict):
x, block_size=node.attrs["blocksize"], data_format=compute_format)
else:
x = tf.transpose(
x, perm=cls.get_perm_from_formats(storage_format, compute_format))
x, perm=get_perm_from_formats(storage_format, compute_format))
y = tf.depth_to_space(
x, block_size=node.attrs["blocksize"], data_format=compute_format)
y = tf.transpose(
y, perm=cls.get_perm_from_formats(compute_format, storage_format))
y, perm=get_perm_from_formats(compute_format, storage_format))
return [y]

@classmethod
Expand Down Expand Up @@ -966,11 +967,11 @@ def handle_space_to_depth(cls, node, input_dict):
x, block_size=node.attrs["blocksize"], data_format=compute_format)
else:
x = tf.transpose(
x, perm=cls.get_perm_from_formats(storage_format, compute_format))
x, perm=get_perm_from_formats(storage_format, compute_format))
y = tf.space_to_depth(
x, block_size=node.attrs["blocksize"], data_format=compute_format)
y = tf.transpose(
y, perm=cls.get_perm_from_formats(compute_format, storage_format))
y, perm=get_perm_from_formats(compute_format, storage_format))
return [y]

@classmethod
Expand Down Expand Up @@ -1055,9 +1056,9 @@ def handle_upsample(cls, node, input_dict):
tf.transpose(
tf.image.resize_images(
tf.transpose(
x, perm=cls.get_perm_from_formats(storage_format,
"NHWC")), size, method),
perm=cls.get_perm_from_formats("NHWC", storage_format))
x, perm=get_perm_from_formats(storage_format,
"NHWC")), size, method),
perm=get_perm_from_formats("NHWC", storage_format))
]

@classmethod
Expand Down
5 changes: 5 additions & 0 deletions onnx_tf/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,11 @@ def get_list_value(attr):
def get_unique_suffix():
return str(uuid.uuid4())[:8]


def get_perm_from_formats(_from, _to):
return list(map(lambda x: _from.find(x), _to))


# Constant string used to indicate that requested padding
# is not natively supported in Tensorflow.
PAD_TF_INCOMPATIBLE = "PAD_TF_INCOMPATIBLE"
1 change: 1 addition & 0 deletions onnx_tf/doc/support_status.md
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ ______
|shape|1|
|sigmoid|1, 6|
|softmax|1|
|space_to_depth|1|
|split_v|1, 2|
|sqrt|1, 6|
|squeeze|1|
Expand Down
34 changes: 34 additions & 0 deletions onnx_tf/frontends/frontend_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from onnx import TensorProto
from onnx_tf.common import as_dtype
from onnx_tf.common import get_unique_suffix
from onnx_tf.common import get_perm_from_formats
from onnx_tf.common import TF_TYPE_TO_ONNX_TYPE
from onnx_tf.frontend import TensorflowFrontendBase

Expand Down Expand Up @@ -224,6 +225,39 @@ def handle_reshape(cls, node, **kwargs):
return helper.make_node(
"Reshape", [node.inputs[0]], [node.name], shape=shape)

@classmethod
@register_onnx_op("SpaceToDepth")
def handle_space_to_depth(cls, node, **kwargs):
blocksize = node.attr["block_size"]
data_format = node.attr.get("data_format", "NHWC").decode()

assert data_format in ["NHWC", "NCHW"], \
("data format {} should be in ['NCHW', 'NHWC'].".format(data_format))

if data_format == "NHWC":
transpose_unique_suffix = get_unique_suffix()
space_to_depth_unique_suffix = get_unique_suffix()
transpose_name = node.inputs[0] + "_T_" + transpose_unique_suffix
space_to_depth_name = node.inputs[0] + "_T_STD_" + space_to_depth_unique_suffix
before_transpose_node = helper.make_node(
"Transpose",
[node.inputs[0]], [transpose_name],
perm=get_perm_from_formats("NHWC", "NCHW"))
space_to_depth_node = helper.make_node(
"SpaceToDepth",
[transpose_name], [space_to_depth_name], blocksize=blocksize)

after_transpose_node = helper.make_node(
"Transpose",
[space_to_depth_name], [node.name],
perm=get_perm_from_formats("NCHW", "NHWC"))

return [before_transpose_node, space_to_depth_node, after_transpose_node]

if data_format == "NCHW":
return helper.make_node(
"SpaceToDepth", [node.inputs[0]], [node.name], blocksize=blocksize)

@classmethod
@register_onnx_op("Split")
def handle_split_v(cls, node, **kwargs):
Expand Down
3 changes: 2 additions & 1 deletion onnx_tf/opset_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@
'Softmax': [1],
'Softplus': [],
'Softsign': [],
'SpaceToDepth': [],
'SpaceToDepth': [1],
'Split': [1, 2],
'Sqrt': [1, 6],
'Squeeze': [1],
Expand Down Expand Up @@ -277,6 +277,7 @@
'shape': [1],
'sigmoid': [1, 6],
'softmax': [1],
'space_to_depth': [1],
'split_v': [1, 2],
'sqrt': [1, 6],
'squeeze': [1],
Expand Down
1 change: 1 addition & 0 deletions test/frontend/test_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ def do_test_expected(self):
("test_reshape", tf.reshape, "Reshape", [get_rnd([10, 10]), [4, 25]], {}),
("test_shape", tf.shape, "Shape", [get_rnd([1, 2, 3, 4])], {}),
("test_sigmoid", tf.sigmoid, "Sigmoid", [get_rnd([10, 10])], {}),
("test_space_to_depth", tf.space_to_depth, "SpaceToDepth", [get_rnd([2, 8, 8, 5])], {"block_size": 2}),
("test_split", tf.split, "split", [get_rnd([10, 10]), [2, 3, 5]], {}),
("test_sqrt", tf.sqrt, "Sqrt", [get_rnd([10, 10])], {}),
("test_squeeze", tf.squeeze, "Squeeze", [get_rnd([1, 1, 10, 10])], {"axis":[0, 1]}),
Expand Down

0 comments on commit cd3768e

Please sign in to comment.