Skip to content

Commit

Permalink
moved dialog specific axesc andctypes to nlp/neural_types.py, refacto…
Browse files Browse the repository at this point in the history
…red the modules

Signed-off-by: nvidia <tkornuta@nvidia.com>
  • Loading branch information
tkornuta-nvidia committed Jun 5, 2020
1 parent 3bd914e commit 28b84b3
Show file tree
Hide file tree
Showing 9 changed files with 87 additions and 58 deletions.
4 changes: 1 addition & 3 deletions nemo/collections/nlp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,5 @@
# limitations under the License.
# =============================================================================

import nemo
from nemo.collections.nlp import callbacks, data, nm, utils

backend = nemo.core.Backend.PyTorch
from nemo.collections.nlp.neural_types import *
73 changes: 73 additions & 0 deletions nemo/collections/nlp/neural_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# ! /usr/bin/python
# -*- coding: utf-8 -*-

# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from nemo.core.neural_types import AxisKindAbstract, ElementType, StringType

__all__ = [
'DialogAxisKind',
'Utterance',
'UserUtterance',
'SystemUtterance',
'AgentUtterance',
'SlotValue',
'MultiWOZBeliefState',
]


class DialogAxisKind(AxisKindAbstract):
""" Class containing definitions of axis kinds specialized for the dialog problem domain. """

Domain = 7


class Utterance(StringType):
"""Element type representing an utterance (e.g. "Is there a train from Ely to Cambridge on Tuesday ?")."""


class UserUtterance(Utterance):
"""Element type representing a utterance expresesd by the user."""


class SystemUtterance(Utterance):
"""Element type representing an utterance produced by the system."""


class AgentUtterance(ElementType):
"""Element type representing utterance returned by an agent (user or system) participating in a dialog."""

# def __str__(self):
# return "Utterance returned by an agent (user or system) participating in a dialog."

# def fields(self):
# return ("agent", "utterance")


class SlotValue(ElementType):
"""Element type representing slot-value pair."""

# def __str__(self):
# return "Slot-value pair"

# def fields(self):
# return ("slot", "value")


class MultiWOZBeliefState(SlotValue):
"""Element type representing MultiWOZ belief state - one per domain."""

# def fields(self):
# return ("book", "semi")
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from nemo.collections.nlp.data.datasets.multiwoz_dataset.dbquery import Database
from nemo.collections.nlp.data.datasets.multiwoz_dataset.multiwoz_slot_trans import REF_SYS_DA, REF_USR_DA
from nemo.core.neural_types import *
from nemo.collections.nlp.neural_types import *
from nemo.utils.decorators import add_port_docs

__all__ = ['RuleBasedDPMMultiWOZ']
Expand Down Expand Up @@ -120,10 +121,10 @@ def input_ports(self):
axes=[
AxisType(kind=AxisKind.Batch, is_list=True),
AxisType(
kind=AxisKind.MultiWOZDomain, is_list=True
kind=DialogAxisKind.Domain, is_list=True
), # always 7 domains - but cannot set size with is_list!
],
elements_type=MultiWOZDomainState(),
elements_type=MultiWOZBeliefState(),
),
'request_state': NeuralType(
axes=[AxisType(kind=AxisKind.Batch, is_list=True), AxisType(kind=AxisKind.Sequence, is_list=True)],
Expand All @@ -143,10 +144,10 @@ def output_ports(self):
axes=[
AxisType(kind=AxisKind.Batch, is_list=True),
AxisType(
kind=AxisKind.MultiWOZDomain, is_list=True
kind=DialogAxisKind.Domain, is_list=True
), # always 7 domains - but cannot set size with is_list!
],
elements_type=MultiWOZDomainState(),
elements_type=MultiWOZBeliefState(),
),
'system_acts': NeuralType(
axes=[AxisType(kind=AxisKind.Batch, is_list=True), AxisType(kind=AxisKind.Sequence, is_list=True)],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

from nemo.backends.pytorch.nm import NonTrainableNM
from nemo.core.neural_types import *
from nemo.collections.nlp.neural_types import *
from nemo.utils import logging
from nemo.utils.decorators import add_port_docs

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

from nemo.backends.pytorch.nm import NonTrainableNM
from nemo.core.neural_types import *
from nemo.collections.nlp.neural_types import *
from nemo.utils import logging
from nemo.utils.decorators import add_port_docs

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from nemo.collections.nlp.data.datasets.multiwoz_dataset.multiwoz_slot_trans import REF_SYS_DA
from nemo.collections.nlp.utils.callback_utils import tensor2numpy
from nemo.core.neural_types import *
from nemo.collections.nlp.neural_types import *
from nemo.utils import logging
from nemo.utils.decorators import add_port_docs

Expand All @@ -52,9 +53,9 @@ def input_ports(self):
'belief_state': NeuralType(
axes=[
AxisType(kind=AxisKind.Batch, is_list=True),
AxisType(kind=AxisKind.MultiWOZDomain, is_list=True), # 7 domains
AxisType(kind=DialogAxisKind.Domain, is_list=True), # 7 domains
],
elements_type=MultiWOZDomainState(),
elements_type=MultiWOZBeliefState(),
),
'user_uttr': NeuralType(axes=[AxisType(kind=AxisKind.Batch, is_list=True)], elements_type=Utterance()),
}
Expand All @@ -68,9 +69,9 @@ def output_ports(self):
'belief_state': NeuralType(
axes=[
AxisType(kind=AxisKind.Batch, is_list=True),
AxisType(kind=AxisKind.MultiWOZDomain, is_list=True), # 7 domains
AxisType(kind=DialogAxisKind.Domain, is_list=True), # 7 domains
],
elements_type=MultiWOZDomainState(),
elements_type=MultiWOZBeliefState(),
),
'request_state': NeuralType(
axes=[AxisType(kind=AxisKind.Batch, is_list=True), AxisType(kind=AxisKind.Sequence, is_list=True)],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

from nemo.backends.pytorch.nm import NonTrainableNM
from nemo.core.neural_types import *
from nemo.collections.nlp.neural_types import *
from nemo.utils import logging
from nemo.utils.decorators import add_port_docs

Expand Down
2 changes: 0 additions & 2 deletions nemo/core/neural_types/axes.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,6 @@ class AxisKind(AxisKindAbstract):
Height = 4
Any = 5
Sequence = 6
DialogDomain = 7
MultiWOZDomain = 8 # A specialized DialogDomain axis?

def __repr__(self):
return self.__str__()
Expand Down
45 changes: 0 additions & 45 deletions nemo/core/neural_types/elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,8 @@
'NormalizedImageValue',
'StringLabel',
'StringType',
'Utterance',
'UserUtterance',
'SystemUtterance',
'AgentUtterance',
'TokenIndex',
'Length',
'SlotValue',
'MultiWOZDomainState',
]

import abc
Expand Down Expand Up @@ -256,28 +250,6 @@ class StringLabel(StringType):
"""


class Utterance(StringType):
"""Element type representing an utterance (e.g. "Is there a train from Ely to Cambridge on Tuesday ?")."""


class UserUtterance(Utterance):
"""Element type representing a utterance expresesd by the user."""


class SystemUtterance(Utterance):
"""Element type representing an utterance produced by the system."""


class AgentUtterance(ElementType):
"""Element type representing utterance returned by an agent (user or system) participating in a dialog."""

# def __str__(self):
# return "Utterance returned by an agent (user or system) participating in a dialog."

# def fields(self):
# return ("agent", "utterance")


class IntType(ElementType):
"""Element type representing a single integer"""

Expand All @@ -288,20 +260,3 @@ class TokenIndex(IntType):

class Length(IntType):
"""Type representing an element storing a "length" (e.g. length of a list)."""


class SlotValue(ElementType):
"""Element type representing slot-value pair."""

# def __str__(self):
# return "Slot-value pair"

# def fields(self):
# return ("slot", "value")


class MultiWOZDomainState(SlotValue):
"""Element type representing MultiWOZ slot-value pair."""

# def fields(self):
# return ("book", "semi")

0 comments on commit 28b84b3

Please sign in to comment.