diff --git a/examples/nlp/dialogue_state_tracking/rule_based_policy_multiwoz.py b/examples/nlp/dialogue_state_tracking/rule_based_policy_multiwoz.py index cebd1d8e0f28..137196488e5b 100644 --- a/examples/nlp/dialogue_state_tracking/rule_based_policy_multiwoz.py +++ b/examples/nlp/dialogue_state_tracking/rule_based_policy_multiwoz.py @@ -41,6 +41,7 @@ from nemo.collections.nlp.data.datasets.multiwoz_dataset.state import init_state from nemo.collections.nlp.nm.non_trainables import ( RuleBasedDPMMultiWOZ, + SystemUtteranceHistoryUpdate, TemplateNLGMultiWOZ, TradeStateUpdateNM, UserUtteranceEncoder, @@ -62,7 +63,7 @@ ] -def forward(dialog_pipeline, system_uttr, user_uttr, dial_history, belief_state): +def forward(dialog_pipeline, user_uttr, dial_history, belief_state): """ Forward pass of the "Complete Dialog Pipeline". @@ -75,7 +76,6 @@ def forward(dialog_pipeline, system_uttr, user_uttr, dial_history, belief_state) Args: user_uttr (str): User utterance - system_uttr (str): Previous system utterance dialog_history (str): Dialogue history contains all previous system and user utterances belief_state (dict): dialogue state Returns: @@ -87,7 +87,7 @@ def forward(dialog_pipeline, system_uttr, user_uttr, dial_history, belief_state) # 1. Forward pass throught Word-Level Dialog State Tracking modules (TRADE). # 1.1. User utterance encoder. dialog_ids, dialog_lens, dial_history = dialog_pipeline.modules[dialog_pipeline.steps[0]].forward( - user_uttr=user_uttr, sys_uttr=system_uttr, dialog_history=dial_history, + user_uttr=user_uttr, dialog_history=dial_history, ) # 1.2. TRADE encoder. outputs, hidden = dialog_pipeline.modules[dialog_pipeline.steps[1]].forward( @@ -111,6 +111,11 @@ def forward(dialog_pipeline, system_uttr, user_uttr, dial_history, belief_state) # 3. Forward pass throught Natural Language Generator module (Template-Based). system_uttr = dialog_pipeline.modules[dialog_pipeline.steps[5]].forward(system_acts=system_acts) + # 4. Update dialog history with system utterance + dial_history = dialog_pipeline.modules[dialog_pipeline.steps[6]].forward( + sys_uttr=system_uttr, dialog_history=dial_history + ) + # Return the updated states and dialog history. return system_uttr, belief_state, dial_history @@ -199,13 +204,16 @@ def forward(dialog_pipeline, system_uttr, user_uttr, dial_history, belief_state) # NLG module. template_nlg = TemplateNLGMultiWOZ() + # Updates dialog history with system utterance. + sys_utter_history_update = SystemUtteranceHistoryUpdate() + # Construct the "evaluation" (inference) neural graph by connecting the modules using nmTensors. # Note: Using the same names for passed nmTensor as in the actual forward pass. with NeuralGraph(operation_mode=OperationMode.evaluation) as dialog_pipeline: # 1.1. User utterance encoder. # Bind all the input ports of this module. dialog_ids, dialog_lens, dial_history = user_utterance_encoder( - user_uttr=dialog_pipeline, sys_uttr=dialog_pipeline, dialog_history=dialog_pipeline, + user_uttr=dialog_pipeline, dialog_history=dialog_pipeline, ) # Fire step 1: 1.2. TRADE encoder. outputs, hidden = trade_encoder(inputs=dialog_ids, input_lens=dialog_lens) @@ -229,6 +237,9 @@ def forward(dialog_pipeline, system_uttr, user_uttr, dial_history, belief_state) # 3. Forward pass throught Natural Language Generator module (Template-Based). system_uttr = template_nlg(system_acts=system_acts) + # 4. Update dialog history with system utterance + dial_history = sys_utter_history_update(sys_uttr=system_uttr, dialog_history=dial_history) + # Show the graph summary. logging.info(dialog_pipeline.summary()) @@ -250,7 +261,7 @@ def forward(dialog_pipeline, system_uttr, user_uttr, dial_history, belief_state) else: # Pass the "user uterance" as inputs to the dialog pipeline. system_uttr, belief_state, dial_history = forward( - dialog_pipeline, system_uttr, user_uttr, dial_history, belief_state + dialog_pipeline, user_uttr, dial_history, belief_state ) elif args.mode == 'example': @@ -262,5 +273,5 @@ def forward(dialog_pipeline, system_uttr, user_uttr, dial_history, belief_state) for user_uttr in example: logging.info("User utterance: %s", user_uttr) system_uttr, belief_state, dial_history = forward( - dialog_pipeline, system_uttr, user_uttr, dial_history, belief_state + dialog_pipeline, user_uttr, dial_history, belief_state ) diff --git a/nemo/collections/nlp/nm/non_trainables/__init__.py b/nemo/collections/nlp/nm/non_trainables/__init__.py index 45af1a64aeb7..b635378d9398 100644 --- a/nemo/collections/nlp/nm/non_trainables/__init__.py +++ b/nemo/collections/nlp/nm/non_trainables/__init__.py @@ -15,6 +15,9 @@ # ============================================================================= from nemo.collections.nlp.nm.non_trainables.dialogue_state_tracking.rule_based_dpm_multiwoz import RuleBasedDPMMultiWOZ +from nemo.collections.nlp.nm.non_trainables.dialogue_state_tracking.system_utterance_history_update import ( + SystemUtteranceHistoryUpdate, +) from nemo.collections.nlp.nm.non_trainables.dialogue_state_tracking.template_nlg_multiwoz import TemplateNLGMultiWOZ from nemo.collections.nlp.nm.non_trainables.dialogue_state_tracking.trade_state_update_nm import TradeStateUpdateNM from nemo.collections.nlp.nm.non_trainables.dialogue_state_tracking.user_utterance_encoder import UserUtteranceEncoder diff --git a/nemo/collections/nlp/nm/non_trainables/dialogue_state_tracking/system_utterance_history_update.py b/nemo/collections/nlp/nm/non_trainables/dialogue_state_tracking/system_utterance_history_update.py new file mode 100755 index 000000000000..f63ff378d20c --- /dev/null +++ b/nemo/collections/nlp/nm/non_trainables/dialogue_state_tracking/system_utterance_history_update.py @@ -0,0 +1,88 @@ +# ============================================================================= +# Copyright 2020 NVIDIA. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= + +''' +This file contains code artifacts adapted from the original implementation: +https://github.com/thu-coai/ConvLab-2/ +''' +import torch + +from nemo.backends.pytorch.nm import NonTrainableNM +from nemo.core.neural_types import * +from nemo.utils import logging +from nemo.utils.decorators import add_port_docs + +__all__ = ['SystemUtteranceHistoryUpdate'] + + +class SystemUtteranceHistoryUpdate(NonTrainableNM): + """ + Updates dialogue history with system utterance. + """ + + @property + @add_port_docs() + def input_ports(self): + """Returns definitions of module input ports. + user_uttr (str): user utterance + sys_uttr (str): system utterace + dialog_history (list): dialogue history, list of system and diaglogue utterances + """ + return { + 'sys_uttr': NeuralType( + axes=[AxisType(kind=AxisKind.Batch, is_list=True)], elements_type=SystemUtterance() + ), + 'dialog_history': NeuralType( + axes=(AxisType(kind=AxisKind.Batch, is_list=True), AxisType(kind=AxisKind.Time, is_list=True),), + elements_type=AgentUtterance(), + ), + } + + @property + @add_port_docs() + def output_ports(self): + """Returns definitions of module output ports. + dialog_history (list): dialogue history, being a list of user and system utterances. + """ + return { + 'dialog_history': NeuralType( + axes=(AxisType(kind=AxisKind.Batch, is_list=True), AxisType(kind=AxisKind.Time, is_list=True),), + elements_type=AgentUtterance(), + ), + } + + def __init__(self): + """ + Initializes the object + Args: + data_desc (obj): data descriptor for MultiWOZ dataset, contains information about domains, slots, + and associated vocabulary + """ + super().__init__() + + def forward(self, sys_uttr, dialog_history): + """ + Returns updated dialog history. + Args: + sys_uttr (str): system utterace + dialog_history (list): dialogue history, list of user and system diaglogue utterances + Returns: + dialog_history (list): updated dialogue history, list of user and system diaglogue utterances + """ + dialog_history.append(["sys", sys_uttr]) + logging.debug("Dialogue history: %s", dialog_history) + + return dialog_history diff --git a/nemo/collections/nlp/nm/non_trainables/dialogue_state_tracking/template_nlg_multiwoz.py b/nemo/collections/nlp/nm/non_trainables/dialogue_state_tracking/template_nlg_multiwoz.py index 23e49583afdf..cbbf63f77ad5 100644 --- a/nemo/collections/nlp/nm/non_trainables/dialogue_state_tracking/template_nlg_multiwoz.py +++ b/nemo/collections/nlp/nm/non_trainables/dialogue_state_tracking/template_nlg_multiwoz.py @@ -111,7 +111,9 @@ def output_ports(self): system_uttr (str): generated system's response """ return { - 'sys_uttr': NeuralType(axes=[AxisType(kind=AxisKind.Batch, is_list=True)], elements_type=Utterance()), + 'sys_uttr': NeuralType( + axes=[AxisType(kind=AxisKind.Batch, is_list=True)], elements_type=SystemUtterance() + ), } def forward(self, system_acts): diff --git a/nemo/collections/nlp/nm/non_trainables/dialogue_state_tracking/user_utterance_encoder.py b/nemo/collections/nlp/nm/non_trainables/dialogue_state_tracking/user_utterance_encoder.py index 13a14c420eaa..563d13273528 100755 --- a/nemo/collections/nlp/nm/non_trainables/dialogue_state_tracking/user_utterance_encoder.py +++ b/nemo/collections/nlp/nm/non_trainables/dialogue_state_tracking/user_utterance_encoder.py @@ -30,7 +30,8 @@ class UserUtteranceEncoder(NonTrainableNM): """ - Encodes dialogue history (system and user utterances) into a Multiwoz dataset format + Updates dialogue history with user utterance and encodes the history (system and user utterances) into + a flat list of tokens (per sample). Args: data_desc (obj): data descriptor for MultiWOZ dataset, contains information about domains, slots, and associated vocabulary @@ -41,12 +42,10 @@ class UserUtteranceEncoder(NonTrainableNM): def input_ports(self): """Returns definitions of module input ports. user_uttr (str): user utterance - sys_uttr (str): system utterace dialog_history (list): dialogue history, list of system and diaglogue utterances """ return { - 'user_uttr': NeuralType(axes=[AxisType(kind=AxisKind.Batch, is_list=True)], elements_type=Utterance()), - 'sys_uttr': NeuralType(axes=[AxisType(kind=AxisKind.Batch, is_list=True)], elements_type=Utterance()), + 'user_uttr': NeuralType(axes=[AxisType(kind=AxisKind.Batch, is_list=True)], elements_type=UserUtterance()), 'dialog_history': NeuralType( axes=(AxisType(kind=AxisKind.Batch, is_list=True), AxisType(kind=AxisKind.Time, is_list=True),), elements_type=AgentUtterance(), @@ -81,20 +80,18 @@ def __init__(self, data_desc): super().__init__() self.data_desc = data_desc - def forward(self, user_uttr, sys_uttr, dialog_history): + def forward(self, user_uttr, dialog_history): """ Returns dialogue utterances in the format accepted by the TRADE Dialogue state tracking model Args: dialog_history (list): dialogue history, list of system and diaglogue utterances user_uttr (str): user utterance - sys_uttr (str): system utterace Returns: dialog_ids (int): token ids for the whole dialogue history dialog_lens (int): length of the whole tokenized dialogue history dialog_history (list): updated dialogue history, list of system and diaglogue utterances """ # TODO: why we update sys utterance, whereas we have only user utterance at that point? - dialog_history.append(["sys", sys_uttr]) dialog_history.append(["user", user_uttr]) logging.debug("Dialogue history: %s", dialog_history) diff --git a/nemo/core/neural_types/elements.py b/nemo/core/neural_types/elements.py index 1d02b4e88a94..d6bd965331cd 100644 --- a/nemo/core/neural_types/elements.py +++ b/nemo/core/neural_types/elements.py @@ -44,6 +44,8 @@ 'StringLabel', 'StringType', 'Utterance', + 'UserUtterance', + 'SystemUtterance', 'AgentUtterance', 'TokenIndex', 'Length', @@ -258,14 +260,22 @@ class Utterance(StringType): """Element type representing an utterance (e.g. "Is there a train from Ely to Cambridge on Tuesday ?").""" +class UserUtterance(Utterance): + """Element type representing a utterance expresesd by the user.""" + + +class SystemUtterance(Utterance): + """Element type representing an utterance produced by the system.""" + + class AgentUtterance(ElementType): """Element type representing utterance returned by an agent (user or system) participating in a dialog.""" - def __str__(self): - return "Utterance returned by an agent (user or system) participating in a dialog." + # def __str__(self): + # return "Utterance returned by an agent (user or system) participating in a dialog." - def fields(self): - return ("agent", "utterance") + # def fields(self): + # return ("agent", "utterance") class IntType(ElementType):