diff --git a/nemo/collections/nlp/__init__.py b/nemo/collections/nlp/__init__.py index f91479ad1614..e3b3dc41aac3 100644 --- a/nemo/collections/nlp/__init__.py +++ b/nemo/collections/nlp/__init__.py @@ -14,7 +14,5 @@ # limitations under the License. # ============================================================================= -import nemo from nemo.collections.nlp import callbacks, data, nm, utils - -backend = nemo.core.Backend.PyTorch +from nemo.collections.nlp.neural_types import * diff --git a/nemo/collections/nlp/neural_types.py b/nemo/collections/nlp/neural_types.py new file mode 100644 index 000000000000..2134de4ef20c --- /dev/null +++ b/nemo/collections/nlp/neural_types.py @@ -0,0 +1,73 @@ +# ! /usr/bin/python +# -*- coding: utf-8 -*- + +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from nemo.core.neural_types import AxisKindAbstract, ElementType, StringType + +__all__ = [ + 'DialogAxisKind', + 'Utterance', + 'UserUtterance', + 'SystemUtterance', + 'AgentUtterance', + 'SlotValue', + 'MultiWOZBeliefState', +] + + +class DialogAxisKind(AxisKindAbstract): + """ Class containing definitions of axis kinds specialized for the dialog problem domain. """ + + Domain = 7 + + +class Utterance(StringType): + """Element type representing an utterance (e.g. "Is there a train from Ely to Cambridge on Tuesday ?").""" + + +class UserUtterance(Utterance): + """Element type representing a utterance expresesd by the user.""" + + +class SystemUtterance(Utterance): + """Element type representing an utterance produced by the system.""" + + +class AgentUtterance(ElementType): + """Element type representing utterance returned by an agent (user or system) participating in a dialog.""" + + # def __str__(self): + # return "Utterance returned by an agent (user or system) participating in a dialog." + + # def fields(self): + # return ("agent", "utterance") + + +class SlotValue(ElementType): + """Element type representing slot-value pair.""" + + # def __str__(self): + # return "Slot-value pair" + + # def fields(self): + # return ("slot", "value") + + +class MultiWOZBeliefState(SlotValue): + """Element type representing MultiWOZ belief state - one per domain.""" + + # def fields(self): + # return ("book", "semi") diff --git a/nemo/collections/nlp/nm/non_trainables/dialogue_state_tracking/rule_based_dpm_multiwoz.py b/nemo/collections/nlp/nm/non_trainables/dialogue_state_tracking/rule_based_dpm_multiwoz.py index 4c70711476e0..8e0b70563c16 100755 --- a/nemo/collections/nlp/nm/non_trainables/dialogue_state_tracking/rule_based_dpm_multiwoz.py +++ b/nemo/collections/nlp/nm/non_trainables/dialogue_state_tracking/rule_based_dpm_multiwoz.py @@ -33,6 +33,7 @@ from nemo.collections.nlp.data.datasets.multiwoz_dataset.dbquery import Database from nemo.collections.nlp.data.datasets.multiwoz_dataset.multiwoz_slot_trans import REF_SYS_DA, REF_USR_DA from nemo.core.neural_types import * +from nemo.collections.nlp.neural_types import * from nemo.utils.decorators import add_port_docs __all__ = ['RuleBasedDPMMultiWOZ'] @@ -120,10 +121,10 @@ def input_ports(self): axes=[ AxisType(kind=AxisKind.Batch, is_list=True), AxisType( - kind=AxisKind.MultiWOZDomain, is_list=True + kind=DialogAxisKind.Domain, is_list=True ), # always 7 domains - but cannot set size with is_list! ], - elements_type=MultiWOZDomainState(), + elements_type=MultiWOZBeliefState(), ), 'request_state': NeuralType( axes=[AxisType(kind=AxisKind.Batch, is_list=True), AxisType(kind=AxisKind.Sequence, is_list=True)], @@ -143,10 +144,10 @@ def output_ports(self): axes=[ AxisType(kind=AxisKind.Batch, is_list=True), AxisType( - kind=AxisKind.MultiWOZDomain, is_list=True + kind=DialogAxisKind.Domain, is_list=True ), # always 7 domains - but cannot set size with is_list! ], - elements_type=MultiWOZDomainState(), + elements_type=MultiWOZBeliefState(), ), 'system_acts': NeuralType( axes=[AxisType(kind=AxisKind.Batch, is_list=True), AxisType(kind=AxisKind.Sequence, is_list=True)], diff --git a/nemo/collections/nlp/nm/non_trainables/dialogue_state_tracking/system_utterance_history_update.py b/nemo/collections/nlp/nm/non_trainables/dialogue_state_tracking/system_utterance_history_update.py index 6dfa0569a78f..bfb8e4299c6d 100755 --- a/nemo/collections/nlp/nm/non_trainables/dialogue_state_tracking/system_utterance_history_update.py +++ b/nemo/collections/nlp/nm/non_trainables/dialogue_state_tracking/system_utterance_history_update.py @@ -21,6 +21,7 @@ from nemo.backends.pytorch.nm import NonTrainableNM from nemo.core.neural_types import * +from nemo.collections.nlp.neural_types import * from nemo.utils import logging from nemo.utils.decorators import add_port_docs diff --git a/nemo/collections/nlp/nm/non_trainables/dialogue_state_tracking/template_nlg_multiwoz.py b/nemo/collections/nlp/nm/non_trainables/dialogue_state_tracking/template_nlg_multiwoz.py index cbbf63f77ad5..92c11e53e142 100644 --- a/nemo/collections/nlp/nm/non_trainables/dialogue_state_tracking/template_nlg_multiwoz.py +++ b/nemo/collections/nlp/nm/non_trainables/dialogue_state_tracking/template_nlg_multiwoz.py @@ -25,6 +25,7 @@ from nemo.backends.pytorch.nm import NonTrainableNM from nemo.core.neural_types import * +from nemo.collections.nlp.neural_types import * from nemo.utils import logging from nemo.utils.decorators import add_port_docs diff --git a/nemo/collections/nlp/nm/non_trainables/dialogue_state_tracking/trade_state_update_nm.py b/nemo/collections/nlp/nm/non_trainables/dialogue_state_tracking/trade_state_update_nm.py index 70eee4f9e296..9cb5e6ac1432 100755 --- a/nemo/collections/nlp/nm/non_trainables/dialogue_state_tracking/trade_state_update_nm.py +++ b/nemo/collections/nlp/nm/non_trainables/dialogue_state_tracking/trade_state_update_nm.py @@ -28,6 +28,7 @@ from nemo.collections.nlp.data.datasets.multiwoz_dataset.multiwoz_slot_trans import REF_SYS_DA from nemo.collections.nlp.utils.callback_utils import tensor2numpy from nemo.core.neural_types import * +from nemo.collections.nlp.neural_types import * from nemo.utils import logging from nemo.utils.decorators import add_port_docs @@ -52,9 +53,9 @@ def input_ports(self): 'belief_state': NeuralType( axes=[ AxisType(kind=AxisKind.Batch, is_list=True), - AxisType(kind=AxisKind.MultiWOZDomain, is_list=True), # 7 domains + AxisType(kind=DialogAxisKind.Domain, is_list=True), # 7 domains ], - elements_type=MultiWOZDomainState(), + elements_type=MultiWOZBeliefState(), ), 'user_uttr': NeuralType(axes=[AxisType(kind=AxisKind.Batch, is_list=True)], elements_type=Utterance()), } @@ -68,9 +69,9 @@ def output_ports(self): 'belief_state': NeuralType( axes=[ AxisType(kind=AxisKind.Batch, is_list=True), - AxisType(kind=AxisKind.MultiWOZDomain, is_list=True), # 7 domains + AxisType(kind=DialogAxisKind.Domain, is_list=True), # 7 domains ], - elements_type=MultiWOZDomainState(), + elements_type=MultiWOZBeliefState(), ), 'request_state': NeuralType( axes=[AxisType(kind=AxisKind.Batch, is_list=True), AxisType(kind=AxisKind.Sequence, is_list=True)], diff --git a/nemo/collections/nlp/nm/non_trainables/dialogue_state_tracking/user_utterance_encoder.py b/nemo/collections/nlp/nm/non_trainables/dialogue_state_tracking/user_utterance_encoder.py index 563d13273528..2a85ac323c66 100755 --- a/nemo/collections/nlp/nm/non_trainables/dialogue_state_tracking/user_utterance_encoder.py +++ b/nemo/collections/nlp/nm/non_trainables/dialogue_state_tracking/user_utterance_encoder.py @@ -22,6 +22,7 @@ from nemo.backends.pytorch.nm import NonTrainableNM from nemo.core.neural_types import * +from nemo.collections.nlp.neural_types import * from nemo.utils import logging from nemo.utils.decorators import add_port_docs diff --git a/nemo/core/neural_types/axes.py b/nemo/core/neural_types/axes.py index 92fec6dccf6b..cc8a1c29cf57 100644 --- a/nemo/core/neural_types/axes.py +++ b/nemo/core/neural_types/axes.py @@ -45,8 +45,6 @@ class AxisKind(AxisKindAbstract): Height = 4 Any = 5 Sequence = 6 - DialogDomain = 7 - MultiWOZDomain = 8 # A specialized DialogDomain axis? def __repr__(self): return self.__str__() diff --git a/nemo/core/neural_types/elements.py b/nemo/core/neural_types/elements.py index d6bd965331cd..f8a6c110c4f7 100644 --- a/nemo/core/neural_types/elements.py +++ b/nemo/core/neural_types/elements.py @@ -43,14 +43,8 @@ 'NormalizedImageValue', 'StringLabel', 'StringType', - 'Utterance', - 'UserUtterance', - 'SystemUtterance', - 'AgentUtterance', 'TokenIndex', 'Length', - 'SlotValue', - 'MultiWOZDomainState', ] import abc @@ -256,28 +250,6 @@ class StringLabel(StringType): """ -class Utterance(StringType): - """Element type representing an utterance (e.g. "Is there a train from Ely to Cambridge on Tuesday ?").""" - - -class UserUtterance(Utterance): - """Element type representing a utterance expresesd by the user.""" - - -class SystemUtterance(Utterance): - """Element type representing an utterance produced by the system.""" - - -class AgentUtterance(ElementType): - """Element type representing utterance returned by an agent (user or system) participating in a dialog.""" - - # def __str__(self): - # return "Utterance returned by an agent (user or system) participating in a dialog." - - # def fields(self): - # return ("agent", "utterance") - - class IntType(ElementType): """Element type representing a single integer""" @@ -288,20 +260,3 @@ class TokenIndex(IntType): class Length(IntType): """Type representing an element storing a "length" (e.g. length of a list).""" - - -class SlotValue(ElementType): - """Element type representing slot-value pair.""" - - # def __str__(self): - # return "Slot-value pair" - - # def fields(self): - # return ("slot", "value") - - -class MultiWOZDomainState(SlotValue): - """Element type representing MultiWOZ slot-value pair.""" - - # def fields(self): - # return ("book", "semi")