diff --git a/onnx_tf/backend.py b/onnx_tf/backend.py index 3c1d67a63..ade20adef 100644 --- a/onnx_tf/backend.py +++ b/onnx_tf/backend.py @@ -473,10 +473,6 @@ def get_data_format(cls, x_rank, support_cuda): compute_format = "N" + sp_dim_string + "C" return storage_format, compute_format - @classmethod - def get_perm_from_formats(cls, _from, _to): - return list(map(lambda x: _from.find(x), _to)) - @classmethod def supports_device(cls, device): if device == "CUDA": diff --git a/onnx_tf/backends/backend_v1.py b/onnx_tf/backends/backend_v1.py index b91f67245..871a9602c 100644 --- a/onnx_tf/backends/backend_v1.py +++ b/onnx_tf/backends/backend_v1.py @@ -25,6 +25,7 @@ ONNX_OP_TO_TF_OP, ONNX_TYPE_TO_TF_TYPE, PAD_TF_INCOMPATIBLE, + get_perm_from_formats, ) import onnx.numpy_helper import onnx.defs @@ -227,7 +228,7 @@ def _pool(cls, node, input_dict, pool_func, pooling_type): data_format=compute_format) else: x = tf.transpose( - x, perm=cls.get_perm_from_formats(storage_format, compute_format)) + x, perm=get_perm_from_formats(storage_format, compute_format)) pooled = pool_func( x, kernel_shape, @@ -236,7 +237,7 @@ def _pool(cls, node, input_dict, pool_func, pooling_type): data_format=compute_format) pooled = tf.transpose( pooled, - perm=cls.get_perm_from_formats(compute_format, storage_format)) + perm=get_perm_from_formats(compute_format, storage_format)) return [pooled] @@ -393,7 +394,7 @@ def _conv(cls, node, input_dict, transpose=False): xs = tf.split(x, num_or_size_splits=group, axis=1) else: x = tf.transpose( - x, perm=cls.get_perm_from_formats(storage_format, compute_format)) + x, perm=get_perm_from_formats(storage_format, compute_format)) xs = tf.split(x, num_or_size_splits=group, axis=-1) if transpose: @@ -485,7 +486,7 @@ def _conv(cls, node, input_dict, transpose=False): output = tf.concat(convolved, axis=-1) output = tf.transpose( output, - perm=cls.get_perm_from_formats(compute_format, storage_format)) + perm=get_perm_from_formats(compute_format, storage_format)) else: bias = input_dict[node.inputs[2]] bias = cls._explicit_broadcast( @@ -499,7 +500,7 @@ def _conv(cls, node, input_dict, transpose=False): output = tf.add(output, bias) output = tf.transpose( output, - perm=cls.get_perm_from_formats(compute_format, storage_format)) + perm=get_perm_from_formats(compute_format, storage_format)) return [output] @@ -522,11 +523,11 @@ def handle_depth_to_space(cls, node, input_dict): x, block_size=node.attrs["blocksize"], data_format=compute_format) else: x = tf.transpose( - x, perm=cls.get_perm_from_formats(storage_format, compute_format)) + x, perm=get_perm_from_formats(storage_format, compute_format)) y = tf.depth_to_space( x, block_size=node.attrs["blocksize"], data_format=compute_format) y = tf.transpose( - y, perm=cls.get_perm_from_formats(compute_format, storage_format)) + y, perm=get_perm_from_formats(compute_format, storage_format)) return [y] @classmethod @@ -966,11 +967,11 @@ def handle_space_to_depth(cls, node, input_dict): x, block_size=node.attrs["blocksize"], data_format=compute_format) else: x = tf.transpose( - x, perm=cls.get_perm_from_formats(storage_format, compute_format)) + x, perm=get_perm_from_formats(storage_format, compute_format)) y = tf.space_to_depth( x, block_size=node.attrs["blocksize"], data_format=compute_format) y = tf.transpose( - y, perm=cls.get_perm_from_formats(compute_format, storage_format)) + y, perm=get_perm_from_formats(compute_format, storage_format)) return [y] @classmethod @@ -1055,9 +1056,9 @@ def handle_upsample(cls, node, input_dict): tf.transpose( tf.image.resize_images( tf.transpose( - x, perm=cls.get_perm_from_formats(storage_format, - "NHWC")), size, method), - perm=cls.get_perm_from_formats("NHWC", storage_format)) + x, perm=get_perm_from_formats(storage_format, + "NHWC")), size, method), + perm=get_perm_from_formats("NHWC", storage_format)) ] @classmethod diff --git a/onnx_tf/common.py b/onnx_tf/common.py index 9c0b1e2fd..31031db63 100644 --- a/onnx_tf/common.py +++ b/onnx_tf/common.py @@ -263,6 +263,11 @@ def get_list_value(attr): def get_unique_suffix(): return str(uuid.uuid4())[:8] + +def get_perm_from_formats(_from, _to): + return list(map(lambda x: _from.find(x), _to)) + + # Constant string used to indicate that requested padding # is not natively supported in Tensorflow. PAD_TF_INCOMPATIBLE = "PAD_TF_INCOMPATIBLE" diff --git a/onnx_tf/doc/support_status.md b/onnx_tf/doc/support_status.md index a9b65559b..c45d3b52c 100644 --- a/onnx_tf/doc/support_status.md +++ b/onnx_tf/doc/support_status.md @@ -171,6 +171,7 @@ ______ |shape|1| |sigmoid|1, 6| |softmax|1| +|space_to_depth|1| |split_v|1, 2| |sqrt|1, 6| |squeeze|1| diff --git a/onnx_tf/frontends/frontend_v1.py b/onnx_tf/frontends/frontend_v1.py index 4df1b6d9c..e3f161f9d 100644 --- a/onnx_tf/frontends/frontend_v1.py +++ b/onnx_tf/frontends/frontend_v1.py @@ -12,6 +12,7 @@ from onnx import TensorProto from onnx_tf.common import as_dtype from onnx_tf.common import get_unique_suffix +from onnx_tf.common import get_perm_from_formats from onnx_tf.common import TF_TYPE_TO_ONNX_TYPE from onnx_tf.frontend import TensorflowFrontendBase @@ -224,6 +225,39 @@ def handle_reshape(cls, node, **kwargs): return helper.make_node( "Reshape", [node.inputs[0]], [node.name], shape=shape) + @classmethod + @register_onnx_op("SpaceToDepth") + def handle_space_to_depth(cls, node, **kwargs): + blocksize = node.attr["block_size"] + data_format = node.attr.get("data_format", "NHWC").decode() + + assert data_format in ["NHWC", "NCHW"], \ + ("data format {} should be in ['NCHW', 'NHWC'].".format(data_format)) + + if data_format == "NHWC": + transpose_unique_suffix = get_unique_suffix() + space_to_depth_unique_suffix = get_unique_suffix() + transpose_name = node.inputs[0] + "_T_" + transpose_unique_suffix + space_to_depth_name = node.inputs[0] + "_T_STD_" + space_to_depth_unique_suffix + before_transpose_node = helper.make_node( + "Transpose", + [node.inputs[0]], [transpose_name], + perm=get_perm_from_formats("NHWC", "NCHW")) + space_to_depth_node = helper.make_node( + "SpaceToDepth", + [transpose_name], [space_to_depth_name], blocksize=blocksize) + + after_transpose_node = helper.make_node( + "Transpose", + [space_to_depth_name], [node.name], + perm=get_perm_from_formats("NCHW", "NHWC")) + + return [before_transpose_node, space_to_depth_node, after_transpose_node] + + if data_format == "NCHW": + return helper.make_node( + "SpaceToDepth", [node.inputs[0]], [node.name], blocksize=blocksize) + @classmethod @register_onnx_op("Split") def handle_split_v(cls, node, **kwargs): diff --git a/onnx_tf/opset_version.py b/onnx_tf/opset_version.py index 82eaccd79..3024c7210 100644 --- a/onnx_tf/opset_version.py +++ b/onnx_tf/opset_version.py @@ -215,7 +215,7 @@ 'Softmax': [1], 'Softplus': [], 'Softsign': [], - 'SpaceToDepth': [], + 'SpaceToDepth': [1], 'Split': [1, 2], 'Sqrt': [1, 6], 'Squeeze': [1], @@ -277,6 +277,7 @@ 'shape': [1], 'sigmoid': [1, 6], 'softmax': [1], + 'space_to_depth': [1], 'split_v': [1, 2], 'sqrt': [1, 6], 'squeeze': [1], diff --git a/test/frontend/test_node.py b/test/frontend/test_node.py index f851d906e..7b4a83767 100644 --- a/test/frontend/test_node.py +++ b/test/frontend/test_node.py @@ -125,6 +125,7 @@ def do_test_expected(self): ("test_reshape", tf.reshape, "Reshape", [get_rnd([10, 10]), [4, 25]], {}), ("test_shape", tf.shape, "Shape", [get_rnd([1, 2, 3, 4])], {}), ("test_sigmoid", tf.sigmoid, "Sigmoid", [get_rnd([10, 10])], {}), +("test_space_to_depth", tf.space_to_depth, "SpaceToDepth", [get_rnd([2, 8, 8, 5])], {"block_size": 2}), ("test_split", tf.split, "split", [get_rnd([10, 10]), [2, 3, 5]], {}), ("test_sqrt", tf.sqrt, "Sqrt", [get_rnd([10, 10])], {}), ("test_squeeze", tf.squeeze, "Squeeze", [get_rnd([1, 1, 10, 10])], {"axis":[0, 1]}),