From 21e7f319c69c82975be6c13102eaf2a2ad60d6e0 Mon Sep 17 00:00:00 2001 From: Oleksii Kuchaiev Date: Tue, 28 Jan 2020 17:01:23 -0800 Subject: [PATCH] fixing some unittests Signed-off-by: Oleksii Kuchaiev --- nemo/backends/pytorch/common/losses.py | 2 +- nemo/backends/pytorch/common/rnn.py | 2 +- nemo/backends/pytorch/common/search.py | 2 +- nemo/core/__init__.py | 20 +++++- nemo/core/callbacks.py | 18 ++++- nemo/core/neural_factory.py | 18 ++++- nemo/core/neural_modules.py | 18 ++++- nemo/core/neural_types/__init__.py | 4 +- nemo/core/neural_types/elements.py | 53 ++++++++++++++- nemo/core/neural_types/neural_type.py | 3 +- tests/core/__init__.py | 0 tests/core/test_neural_types.py | 92 ++++++++++++++++++++++++++ 12 files changed, 219 insertions(+), 13 deletions(-) create mode 100644 tests/core/__init__.py create mode 100644 tests/core/test_neural_types.py diff --git a/nemo/backends/pytorch/common/losses.py b/nemo/backends/pytorch/common/losses.py index 295c09ba1ce4..90a20a633c81 100644 --- a/nemo/backends/pytorch/common/losses.py +++ b/nemo/backends/pytorch/common/losses.py @@ -2,7 +2,7 @@ from torch import nn from nemo.backends.pytorch.nm import LossNM -from nemo.core.neural_types import AxisType, BatchTag, ChannelTag, NeuralType, RegressionTag, TimeTag +from nemo.core import AxisType, BatchTag, ChannelTag, NeuralType, RegressionTag, TimeTag __all__ = ['SequenceLoss', 'CrossEntropyLoss', 'MSELoss'] diff --git a/nemo/backends/pytorch/common/rnn.py b/nemo/backends/pytorch/common/rnn.py index c7f6fc66f5bc..c171ad7e00fd 100644 --- a/nemo/backends/pytorch/common/rnn.py +++ b/nemo/backends/pytorch/common/rnn.py @@ -8,7 +8,7 @@ from nemo.backends.pytorch.common.parts import Attention from nemo.backends.pytorch.nm import TrainableNM -from nemo.core.neural_types import AxisType, BatchTag, ChannelTag, NeuralType, TimeTag +from nemo.core import AxisType, BatchTag, ChannelTag, NeuralType, TimeTag from nemo.utils.misc import pad_to diff --git a/nemo/backends/pytorch/common/search.py b/nemo/backends/pytorch/common/search.py index 812c22ce2cfd..7b449acdd0d3 100644 --- a/nemo/backends/pytorch/common/search.py +++ b/nemo/backends/pytorch/common/search.py @@ -3,7 +3,7 @@ import torch from nemo.backends.pytorch.nm import NonTrainableNM -from nemo.core.neural_types import AxisType, BatchTag, ChannelTag, NeuralType, TimeTag +from nemo.core import AxisType, BatchTag, ChannelTag, NeuralType, TimeTag INF = float('inf') BIG_NUM = 1e4 diff --git a/nemo/core/__init__.py b/nemo/core/__init__.py index 7b13691e476a..06a0050f1b7e 100644 --- a/nemo/core/__init__.py +++ b/nemo/core/__init__.py @@ -1,5 +1,21 @@ -# Copyright (c) 2019 NVIDIA Corporation +# ! /usr/bin/python +# -*- coding: utf-8 -*- + +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from .callbacks import * from .neural_factory import * from .neural_modules import * -from .neural_types import * +from .old_neural_types import * diff --git a/nemo/core/callbacks.py b/nemo/core/callbacks.py index 4f6c94ba01dc..1ebf3675e270 100644 --- a/nemo/core/callbacks.py +++ b/nemo/core/callbacks.py @@ -1,4 +1,20 @@ -# Copyright (c) 2019 NVIDIA Corporation +# ! /usr/bin/python +# -*- coding: utf-8 -*- + +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import glob import os import sys diff --git a/nemo/core/neural_factory.py b/nemo/core/neural_factory.py index 086af2a04fbf..9f61c086b58e 100644 --- a/nemo/core/neural_factory.py +++ b/nemo/core/neural_factory.py @@ -1,4 +1,20 @@ -# Copyright (c) 2019 NVIDIA Corporation +# ! /usr/bin/python +# -*- coding: utf-8 -*- + +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + __all__ = [ 'Backend', 'ModelMode', diff --git a/nemo/core/neural_modules.py b/nemo/core/neural_modules.py index 663bb3da3184..373839ee93b2 100644 --- a/nemo/core/neural_modules.py +++ b/nemo/core/neural_modules.py @@ -1,4 +1,20 @@ -# Copyright (c) 2019 NVIDIA Corporation +# ! /usr/bin/python +# -*- coding: utf-8 -*- + +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """This file contains NeuralModule and NmTensor classes.""" __all__ = ['WeightShareTransform', 'NeuralModule'] diff --git a/nemo/core/neural_types/__init__.py b/nemo/core/neural_types/__init__.py index 92c9b37c32b6..124adc132c72 100644 --- a/nemo/core/neural_types/__init__.py +++ b/nemo/core/neural_types/__init__.py @@ -15,7 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .elements import * from .axes import * from .comparison import * -from .neural_type import * \ No newline at end of file +from .elements import * +from .neural_type import * diff --git a/nemo/core/neural_types/elements.py b/nemo/core/neural_types/elements.py index 0b3626556b96..b806280677f5 100644 --- a/nemo/core/neural_types/elements.py +++ b/nemo/core/neural_types/elements.py @@ -15,10 +15,20 @@ # See the License for the specific language governing permissions and # limitations under the License. -__all__ = ['ElementType', 'VoidType'] +__all__ = [ + 'ElementType', + 'VoidType', + 'ChannelType', + 'AcousticEncodedRepresentation', + 'AudioSignal', + 'SpectrogramType', + 'MelSpectrogramType', + 'MFCCSpectrogramType', +] import abc from abc import ABC, abstractmethod -from typing import Tuple, Optional, Dict +from typing import Dict, Optional, Tuple + from .comparison import NeuralTypeComparisonResult @@ -81,3 +91,42 @@ def __str__(self): def compare(cls, second: abc.ABCMeta) -> NeuralTypeComparisonResult: return NeuralTypeComparisonResult.SAME + + +# TODO: Consider moving these files elsewhere +class ChannelType(ElementType): + def __str__(self): + return "convolutional channel value" + + +class AcousticEncodedRepresentation(ChannelType): + def __str__(self): + return "encoded representation returned by the acoustic encoder model" + + +class AudioSignal(ElementType): + def __str__(self): + return "encoded representation returned by the acoustic encoder model" + + def __init__(self, freq=16000): + self._params = {} + self._params['freq'] = freq + + @property + def type_parameters(self): + return self._params + + +class SpectrogramType(ChannelType): + def __str__(self): + return "generic spectorgram type" + + +class MelSpectrogramType(SpectrogramType): + def __str__(self): + return "mel spectorgram type" + + +class MFCCSpectrogramType(SpectrogramType): + def __str__(self): + return "mfcc spectorgram type" diff --git a/nemo/core/neural_types/neural_type.py b/nemo/core/neural_types/neural_type.py index cbb216ef80ef..346668e7d303 100644 --- a/nemo/core/neural_types/neural_type.py +++ b/nemo/core/neural_types/neural_type.py @@ -26,8 +26,9 @@ ] import uuid from typing import Tuple + +from .axes import AxisKind, AxisType from .comparison import NeuralTypeComparisonResult -from .axes import AxisType, AxisKind from .elements import * diff --git a/tests/core/__init__.py b/tests/core/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/core/test_neural_types.py b/tests/core/test_neural_types.py new file mode 100644 index 000000000000..bffdf705bd56 --- /dev/null +++ b/tests/core/test_neural_types.py @@ -0,0 +1,92 @@ +# ! /usr/bin/python +# -*- coding: utf-8 -*- + +# Copyright 2019 NVIDIA. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= + +from nemo.core.neural_types import ( + AcousticEncodedRepresentation, + AudioSignal, + AxisKind, + AxisType, + ChannelType, + MelSpectrogramType, + MFCCSpectrogramType, + NeuralType, + NeuralTypeComparisonResult, + SpectrogramType, +) +from tests.common_setup import NeMoUnitTest + + +class NeuralTypeSystemTests(NeMoUnitTest): + def test_short_vs_long_version(self): + long_version = NeuralType( + elements_type=AcousticEncodedRepresentation(), + axes=(AxisType(AxisKind.Batch, None), AxisType(AxisKind.Dimension, None), AxisType(AxisKind.Time, None)), + ) + short_version = NeuralType(AcousticEncodedRepresentation(), ('B', 'D', 'T')) + self.assertEqual(long_version.compare(short_version), NeuralTypeComparisonResult.SAME) + self.assertEqual(short_version.compare(long_version), NeuralTypeComparisonResult.SAME) + + def test_parameterized_type_audio_sampling_frequency(self): + audio16K = NeuralType(AudioSignal(16000), axes=('B', 'T')) + audio8K = NeuralType(AudioSignal(8000), axes=('B', 'T')) + another16K = NeuralType(AudioSignal(16000), axes=('B', 'T')) + + self.assertEqual(audio8K.compare(audio16K), NeuralTypeComparisonResult.SAME_TYPE_INCOMPATIBLE_PARAMS) + self.assertEqual(audio16K.compare(audio8K), NeuralTypeComparisonResult.SAME_TYPE_INCOMPATIBLE_PARAMS) + self.assertEqual(another16K.compare(audio16K), NeuralTypeComparisonResult.SAME) + self.assertEqual(audio16K.compare(another16K), NeuralTypeComparisonResult.SAME) + + def test_transpose_same(self): + audio16K = NeuralType(AudioSignal(16000), axes=('B', 'T')) + audio16K_t = NeuralType(AudioSignal(16000), axes=('T', 'B')) + self.assertEqual(audio16K.compare(audio16K_t), NeuralTypeComparisonResult.TRANSPOSE_SAME) + + def test_inheritance_spec_augment_example(self): + input = NeuralType(SpectrogramType(), ('B', 'D', 'T')) + out1 = NeuralType(MelSpectrogramType(), ('B', 'D', 'T')) + out2 = NeuralType(MFCCSpectrogramType(), ('B', 'D', 'T')) + self.assertEqual(out1.compare(out2), NeuralTypeComparisonResult.INCOMPATIBLE) + self.assertEqual(out2.compare(out1), NeuralTypeComparisonResult.INCOMPATIBLE) + self.assertEqual(input.compare(out1), NeuralTypeComparisonResult.GREATER) + self.assertEqual(input.compare(out2), NeuralTypeComparisonResult.GREATER) + self.assertEqual(out1.compare(input), NeuralTypeComparisonResult.LESS) + self.assertEqual(out2.compare(input), NeuralTypeComparisonResult.LESS) + + def test_list_of_lists(self): + T1 = NeuralType( + elements_type=ChannelType(), + axes=( + AxisType(kind=AxisKind.Batch, size=None, is_list=True), + AxisType(kind=AxisKind.Time, size=None, is_list=True), + AxisType(kind=AxisKind.Dimension, size=32, is_list=False), + AxisType(kind=AxisKind.Dimension, size=128, is_list=False), + AxisType(kind=AxisKind.Dimension, size=256, is_list=False), + ), + ) + T2 = NeuralType( + elements_type=ChannelType(), + axes=( + AxisType(kind=AxisKind.Batch, size=None, is_list=False), + AxisType(kind=AxisKind.Time, size=None, is_list=False), + AxisType(kind=AxisKind.Dimension, size=32, is_list=False), + AxisType(kind=AxisKind.Dimension, size=128, is_list=False), + AxisType(kind=AxisKind.Dimension, size=256, is_list=False), + ), + ) + # TODO: should this be incompatible instead??? + self.assertEqual(T1.compare(T2), NeuralTypeComparisonResult.TRANSPOSE_SAME)