From 02ddf841c9b68b24918c88bd30eb7d22dbde7f34 Mon Sep 17 00:00:00 2001
From: DawerG
Date: Tue, 7 Jul 2020 12:36:43 -0700
Subject: [PATCH] Adding leaky_relu and upsample_nearest2d for yolov5 (#758)

* ultralytics Yolov5 Passes

* Added unit tests for newly added layers

Co-authored-by: Gitesh Dawer
---
 .../converters/mil/frontend/torch/__init__.py |  2 +-
 .../converters/mil/frontend/torch/ops.py      | 70 +++++++++++++++----
 .../mil/frontend/torch/test/test_numerical.py | 26 +++++++
 .../mil/frontend/torch/test/test_ops.py       | 10 ++-
 4 files changed, 91 insertions(+), 17 deletions(-)

diff --git a/coremltools/converters/mil/frontend/torch/__init__.py b/coremltools/converters/mil/frontend/torch/__init__.py
index d83660848..5991cc2ad 100644
--- a/coremltools/converters/mil/frontend/torch/__init__.py
+++ b/coremltools/converters/mil/frontend/torch/__init__.py
@@ -3,7 +3,7 @@
 # Use of this source code is governed by a BSD-3-clause license that can be
 # found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause
 
-from ....._deps import _HAS_TORCH
+from coremltools._deps import _HAS_TORCH
 
 register_torch_op = None
 
diff --git a/coremltools/converters/mil/frontend/torch/ops.py b/coremltools/converters/mil/frontend/torch/ops.py
index 483002eb7..1e19c7cd4 100644
--- a/coremltools/converters/mil/frontend/torch/ops.py
+++ b/coremltools/converters/mil/frontend/torch/ops.py
@@ -462,6 +462,12 @@ def relu(context, node):
     res = mb.relu(x=inputs[0], name=node.name)
     context.add(res)
 
+@register_torch_op(torch_alias=["leaky_relu_"])
+def leaky_relu(context, node):
+    inputs = _get_inputs(context, node, expected=2)
+
+    res = mb.leaky_relu(x=inputs[0], alpha=inputs[1], name=node.name)
+    context.add(res)
 def _adjust_pad_for_ceil_mode(input_shape, kernel_sizes, stride_sizes, pad_sizes):
     """
     TODO Given an input tensor and pooling parameters, add the extra input
@@ -1101,19 +1107,8 @@ def lstm(context, node):
     else:
         context.add(output, name)
 
-
-@register_torch_op
-def upsample_bilinear2d(context, node):
-    inputs = _get_inputs(context, node)
-    _input = inputs[0]
-    output_size = inputs[1]
-    align_corners = bool(inputs[2].val)
-
-    if len(inputs) == 5:
-        # For torch==1.5.0, upsample_bilinear2d has 5 inputs.
-        scales_h = inputs[3]
-        scales_w = inputs[4]
-
+def _get_scales_from_output_size(output_size, input_shape):
+    scales = []
     if output_size is not None:
         # @output_size will be a list if scales was provided or a
         # single var if output size was provided
@@ -1131,8 +1126,26 @@
         # e.g. if output size = 34 and input size = 2, then scale will be
         # 17, which can get represented as 16.9999, resulting in an output size of 33
         # instead of 34, without this correction.
-        scales_h = (output_size[0] + 1e-4) / float(_input.shape[-2])
-        scales_w = (output_size[1] + 1e-4) / float(_input.shape[-1])
+        scales_h = (output_size[0] + 1e-4) / float(input_shape[-2])
+        scales_w = (output_size[1] + 1e-4) / float(input_shape[-1])
+        scales = [scales_h, scales_w]
+    return scales
+
+@register_torch_op
+def upsample_bilinear2d(context, node):
+    inputs = _get_inputs(context, node)
+    _input = inputs[0]
+    output_size = inputs[1]
+    align_corners = bool(inputs[2].val)
+
+    if len(inputs) == 5:
+        # For torch==1.5.0, upsample_bilinear2d has 5 inputs.
+        scales_h = inputs[3]
+        scales_w = inputs[4]
+
+    scales = _get_scales_from_output_size(output_size, _input.shape)
+    if scales:
+        scales_h, scales_w = scales
 
     upsample_bilinear = mb.upsample_bilinear(
         x=_input,
@@ -1143,6 +1156,33 @@
     )
     context.add(upsample_bilinear)
 
+@register_torch_op
+def upsample_nearest2d(context, node):
+    inputs = _get_inputs(context, node)
+    _input = inputs[0]
+    output_size = inputs[1]
+    if len(inputs) == 4:
+        scales_h = inputs[2]
+        scales_w = inputs[3]
+
+    scales = _get_scales_from_output_size(output_size, _input.shape)
+    if scales:
+        scales_h, scales_w = scales
+
+    if (
+        abs(scales_h - round(scales_h)) > 0.001
+        or abs(scales_w - round(scales_w)) > 0.001
+    ):
+        raise ValueError("Layer upsample_nearest2d only supports integral scales. Provided scales: {}. "
+                         "Please use upsample_bilinear2d for fractional scales".format(scales))
+
+    upsample_nearest2d = mb.upsample_nearest_neighbor(
+        x=_input,
+        upscale_factor_height=int(round(scales_h)),
+        upscale_factor_width=int(round(scales_w)),
+        name=node.name,
+    )
+    context.add(upsample_nearest2d)
 
 @register_torch_op(torch_alias=["listunpack"])
 def tupleunpack(context, node):
diff --git a/coremltools/converters/mil/frontend/torch/test/test_numerical.py b/coremltools/converters/mil/frontend/torch/test/test_numerical.py
index c3651328c..a3fbd3bf0 100644
--- a/coremltools/converters/mil/frontend/torch/test/test_numerical.py
+++ b/coremltools/converters/mil/frontend/torch/test/test_numerical.py
@@ -206,6 +206,17 @@ def test_upsample_bilinear2d_with_output_size(self, output_size, align_corners):
         )
         run_numerical_test(input_shape, model)
 
+    @pytest.mark.parametrize(
+        "output_size", [(10, 10), (20, 30), (20, 20), (30, 20), (190, 170)]
+    )
+    def test_upsample_nearest2d_with_output_size(self, output_size):
+        input_shape = (1, 3, 10, 10)
+        model = ModuleWrapper(
+            nn.functional.interpolate,
+            {"size": output_size, "mode": "nearest"},
+        )
+        run_numerical_test(input_shape, model)
+
     @pytest.mark.parametrize(
         "scales_h, scales_w, align_corners",
         [x for x in itertools.product([2, 3, 4.5], [4, 5, 5.5], [True, False])],
@@ -222,6 +233,21 @@ def test_upsample_bilinear2d_with_scales(self, scales_h, scales_w, align_corners
         )
         run_numerical_test(input_shape, model)
 
+    @pytest.mark.parametrize(
+        "scales_h, scales_w",
+        [x for x in itertools.product([2, 3, 5], [4, 5, 2])],
+    )
+    def test_upsample_nearest2d_with_scales(self, scales_h, scales_w):
+        input_shape = (1, 3, 10, 10)
+        model = ModuleWrapper(
+            nn.functional.interpolate,
+            {
+                "scale_factor": (scales_h, scales_w),
+                "mode": "nearest",
+            },
+        )
+        run_numerical_test(input_shape, model)
+
     @pytest.mark.parametrize(
         "input_shape, eps",
         itertools.product([(1, 3, 15, 15), (1, 1, 1, 1)], [1e-5, 1e-9]),
diff --git a/coremltools/converters/mil/frontend/torch/test/test_ops.py b/coremltools/converters/mil/frontend/torch/test/test_ops.py
index ed86f7757..2271c3114 100644
--- a/coremltools/converters/mil/frontend/torch/test/test_ops.py
+++ b/coremltools/converters/mil/frontend/torch/test/test_ops.py
@@ -7,7 +7,9 @@
 
 import numpy as np
 import pytest
-import torch
+
+torch = pytest.importorskip("torch")
+
 import torch.nn as nn
 import torch.nn.functional as F
 
@@ -1419,6 +1421,12 @@ def test_relu(self, context):
             context, (3, 4, 5), [], "relu", ops.relu, nn.ReLU().eval(), atol=1e-6
         )
 
+    @pytest.mark.parametrize("alpha", [0.1, 2.0, 1.5])
+    def test_leaky_relu(self, context, alpha):
+        self._test_activation(
+            context, (3, 4, 5), [alpha], "leaky_relu", ops.leaky_relu,
+            nn.LeakyReLU(negative_slope=alpha).eval(), atol=1e-6
+        )
     @pytest.mark.parametrize("dim", [0, 1, 2])
     def test_log_softmax(self, context, dim):
         self._test_activation(
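
Reviewer notes (not part of the patch):

A quick way to exercise both new conversion paths end to end is to trace and convert a small module that uses LeakyReLU and nearest-neighbor upsampling. This is a minimal sketch assuming the coremltools 4.0 ct.convert TorchScript frontend; TinyUpsampleBlock is a hypothetical stand-in, not YOLOv5 code.

    # Minimal smoke test for the new leaky_relu and upsample_nearest2d paths.
    import torch
    import torch.nn as nn
    import coremltools as ct

    class TinyUpsampleBlock(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
            # Traces to aten::leaky_relu, handled by the new ops.leaky_relu.
            self.act = nn.LeakyReLU(negative_slope=0.1)

        def forward(self, x):
            x = self.act(self.conv(x))
            # Integral scale factor, handled by the new ops.upsample_nearest2d.
            return nn.functional.interpolate(x, scale_factor=2, mode="nearest")

    example = torch.rand(1, 3, 32, 32)
    traced = torch.jit.trace(TinyUpsampleBlock().eval(), example)
    mlmodel = ct.convert(traced, inputs=[ct.TensorType(name="x", shape=example.shape)])

Since the op is registered with torch_alias=["leaky_relu_"], the in-place variant (nn.LeakyReLU(inplace=True)) takes the same conversion path.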
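
The integral-scale restriction in upsample_nearest2d can also be sanity-checked in isolation. The toy reproduction below (plain Python that mirrors, but does not import, the converter code) shows both pieces: the 1e-4 nudge _get_scales_from_output_size applies so float error cannot make floor(scale * input_size) land one pixel short, and the 0.001 tolerance that accepts a recovered scale of 2.00001 but rejects 2.5.

    # Toy reproduction of the converter's scale handling; names are local stand-ins.
    def scales_from_output_size(output_size, input_shape):
        # Same epsilon trick as _get_scales_from_output_size: nudge the output
        # size up so the floor formula cannot round down by one.
        scales_h = (output_size[0] + 1e-4) / float(input_shape[-2])
        scales_w = (output_size[1] + 1e-4) / float(input_shape[-1])
        return scales_h, scales_w

    def to_integral_scales(scales_h, scales_w):
        # Same check upsample_nearest2d performs before emitting
        # mb.upsample_nearest_neighbor, which only takes integer factors.
        if (
            abs(scales_h - round(scales_h)) > 0.001
            or abs(scales_w - round(scales_w)) > 0.001
        ):
            raise ValueError("only integral scales supported: {}".format([scales_h, scales_w]))
        return int(round(scales_h)), int(round(scales_w))

    print(to_integral_scales(*scales_from_output_size((20, 30), (1, 3, 10, 10))))  # (2, 3)
    # (25, 25) from a 10x10 input recovers scale 2.5 and raises ValueError.

This is also why the new numerical tests parametrize only integral scale factors ([2, 3, 5] x [4, 5, 2]) for nearest mode, while the bilinear tests keep fractional ones like 4.5 and 5.5.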