Skip to content

Commit

Permalink
fixing some unittests
Browse files Browse the repository at this point in the history
Signed-off-by: Oleksii Kuchaiev <okuchaiev@nvidia.com>
  • Loading branch information
okuchaiev committed Jan 29, 2020
1 parent 842f710 commit 21e7f31
Show file tree
Hide file tree
Showing 12 changed files with 219 additions and 13 deletions.
2 changes: 1 addition & 1 deletion nemo/backends/pytorch/common/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from torch import nn

from nemo.backends.pytorch.nm import LossNM
from nemo.core.neural_types import AxisType, BatchTag, ChannelTag, NeuralType, RegressionTag, TimeTag
from nemo.core import AxisType, BatchTag, ChannelTag, NeuralType, RegressionTag, TimeTag

__all__ = ['SequenceLoss', 'CrossEntropyLoss', 'MSELoss']

Expand Down
2 changes: 1 addition & 1 deletion nemo/backends/pytorch/common/rnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from nemo.backends.pytorch.common.parts import Attention
from nemo.backends.pytorch.nm import TrainableNM
from nemo.core.neural_types import AxisType, BatchTag, ChannelTag, NeuralType, TimeTag
from nemo.core import AxisType, BatchTag, ChannelTag, NeuralType, TimeTag
from nemo.utils.misc import pad_to


Expand Down
2 changes: 1 addition & 1 deletion nemo/backends/pytorch/common/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import torch

from nemo.backends.pytorch.nm import NonTrainableNM
from nemo.core.neural_types import AxisType, BatchTag, ChannelTag, NeuralType, TimeTag
from nemo.core import AxisType, BatchTag, ChannelTag, NeuralType, TimeTag

INF = float('inf')
BIG_NUM = 1e4
Expand Down
20 changes: 18 additions & 2 deletions nemo/core/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,21 @@
# Copyright (c) 2019 NVIDIA Corporation
# ! /usr/bin/python
# -*- coding: utf-8 -*-

# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .callbacks import *
from .neural_factory import *
from .neural_modules import *
from .neural_types import *
from .old_neural_types import *
18 changes: 17 additions & 1 deletion nemo/core/callbacks.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,20 @@
# Copyright (c) 2019 NVIDIA Corporation
# ! /usr/bin/python
# -*- coding: utf-8 -*-

# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import glob
import os
import sys
Expand Down
18 changes: 17 additions & 1 deletion nemo/core/neural_factory.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,20 @@
# Copyright (c) 2019 NVIDIA Corporation
# ! /usr/bin/python
# -*- coding: utf-8 -*-

# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

__all__ = [
'Backend',
'ModelMode',
Expand Down
18 changes: 17 additions & 1 deletion nemo/core/neural_modules.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,20 @@
# Copyright (c) 2019 NVIDIA Corporation
# ! /usr/bin/python
# -*- coding: utf-8 -*-

# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""This file contains NeuralModule and NmTensor classes."""
__all__ = ['WeightShareTransform', 'NeuralModule']

Expand Down
4 changes: 2 additions & 2 deletions nemo/core/neural_types/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .elements import *
from .axes import *
from .comparison import *
from .neural_type import *
from .elements import *
from .neural_type import *
53 changes: 51 additions & 2 deletions nemo/core/neural_types/elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.

__all__ = ['ElementType', 'VoidType']
__all__ = [
'ElementType',
'VoidType',
'ChannelType',
'AcousticEncodedRepresentation',
'AudioSignal',
'SpectrogramType',
'MelSpectrogramType',
'MFCCSpectrogramType',
]
import abc
from abc import ABC, abstractmethod
from typing import Tuple, Optional, Dict
from typing import Dict, Optional, Tuple

from .comparison import NeuralTypeComparisonResult


Expand Down Expand Up @@ -81,3 +91,42 @@ def __str__(self):

def compare(cls, second: abc.ABCMeta) -> NeuralTypeComparisonResult:
return NeuralTypeComparisonResult.SAME


# TODO: Consider moving these files elsewhere
class ChannelType(ElementType):
def __str__(self):
return "convolutional channel value"


class AcousticEncodedRepresentation(ChannelType):
def __str__(self):
return "encoded representation returned by the acoustic encoder model"


class AudioSignal(ElementType):
def __str__(self):
return "encoded representation returned by the acoustic encoder model"

def __init__(self, freq=16000):
self._params = {}
self._params['freq'] = freq

@property
def type_parameters(self):
return self._params


class SpectrogramType(ChannelType):
def __str__(self):
return "generic spectorgram type"


class MelSpectrogramType(SpectrogramType):
def __str__(self):
return "mel spectorgram type"


class MFCCSpectrogramType(SpectrogramType):
def __str__(self):
return "mfcc spectorgram type"
3 changes: 2 additions & 1 deletion nemo/core/neural_types/neural_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,9 @@
]
import uuid
from typing import Tuple

from .axes import AxisKind, AxisType
from .comparison import NeuralTypeComparisonResult
from .axes import AxisType, AxisKind
from .elements import *


Expand Down
Empty file added tests/core/__init__.py
Empty file.
92 changes: 92 additions & 0 deletions tests/core/test_neural_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
# ! /usr/bin/python
# -*- coding: utf-8 -*-

# Copyright 2019 NVIDIA. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

from nemo.core.neural_types import (
AcousticEncodedRepresentation,
AudioSignal,
AxisKind,
AxisType,
ChannelType,
MelSpectrogramType,
MFCCSpectrogramType,
NeuralType,
NeuralTypeComparisonResult,
SpectrogramType,
)
from tests.common_setup import NeMoUnitTest


class NeuralTypeSystemTests(NeMoUnitTest):
def test_short_vs_long_version(self):
long_version = NeuralType(
elements_type=AcousticEncodedRepresentation(),
axes=(AxisType(AxisKind.Batch, None), AxisType(AxisKind.Dimension, None), AxisType(AxisKind.Time, None)),
)
short_version = NeuralType(AcousticEncodedRepresentation(), ('B', 'D', 'T'))
self.assertEqual(long_version.compare(short_version), NeuralTypeComparisonResult.SAME)
self.assertEqual(short_version.compare(long_version), NeuralTypeComparisonResult.SAME)

def test_parameterized_type_audio_sampling_frequency(self):
audio16K = NeuralType(AudioSignal(16000), axes=('B', 'T'))
audio8K = NeuralType(AudioSignal(8000), axes=('B', 'T'))
another16K = NeuralType(AudioSignal(16000), axes=('B', 'T'))

self.assertEqual(audio8K.compare(audio16K), NeuralTypeComparisonResult.SAME_TYPE_INCOMPATIBLE_PARAMS)
self.assertEqual(audio16K.compare(audio8K), NeuralTypeComparisonResult.SAME_TYPE_INCOMPATIBLE_PARAMS)
self.assertEqual(another16K.compare(audio16K), NeuralTypeComparisonResult.SAME)
self.assertEqual(audio16K.compare(another16K), NeuralTypeComparisonResult.SAME)

def test_transpose_same(self):
audio16K = NeuralType(AudioSignal(16000), axes=('B', 'T'))
audio16K_t = NeuralType(AudioSignal(16000), axes=('T', 'B'))
self.assertEqual(audio16K.compare(audio16K_t), NeuralTypeComparisonResult.TRANSPOSE_SAME)

def test_inheritance_spec_augment_example(self):
input = NeuralType(SpectrogramType(), ('B', 'D', 'T'))
out1 = NeuralType(MelSpectrogramType(), ('B', 'D', 'T'))
out2 = NeuralType(MFCCSpectrogramType(), ('B', 'D', 'T'))
self.assertEqual(out1.compare(out2), NeuralTypeComparisonResult.INCOMPATIBLE)
self.assertEqual(out2.compare(out1), NeuralTypeComparisonResult.INCOMPATIBLE)
self.assertEqual(input.compare(out1), NeuralTypeComparisonResult.GREATER)
self.assertEqual(input.compare(out2), NeuralTypeComparisonResult.GREATER)
self.assertEqual(out1.compare(input), NeuralTypeComparisonResult.LESS)
self.assertEqual(out2.compare(input), NeuralTypeComparisonResult.LESS)

def test_list_of_lists(self):
T1 = NeuralType(
elements_type=ChannelType(),
axes=(
AxisType(kind=AxisKind.Batch, size=None, is_list=True),
AxisType(kind=AxisKind.Time, size=None, is_list=True),
AxisType(kind=AxisKind.Dimension, size=32, is_list=False),
AxisType(kind=AxisKind.Dimension, size=128, is_list=False),
AxisType(kind=AxisKind.Dimension, size=256, is_list=False),
),
)
T2 = NeuralType(
elements_type=ChannelType(),
axes=(
AxisType(kind=AxisKind.Batch, size=None, is_list=False),
AxisType(kind=AxisKind.Time, size=None, is_list=False),
AxisType(kind=AxisKind.Dimension, size=32, is_list=False),
AxisType(kind=AxisKind.Dimension, size=128, is_list=False),
AxisType(kind=AxisKind.Dimension, size=256, is_list=False),
),
)
# TODO: should this be incompatible instead???
self.assertEqual(T1.compare(T2), NeuralTypeComparisonResult.TRANSPOSE_SAME)

0 comments on commit 21e7f31

Please sign in to comment.