Skip to content

Commit

Permalink
Added module responsible for sys uttr dialog history update
Browse files Browse the repository at this point in the history
Signed-off-by: nvidia <tkornuta@nvidia.com>
  • Loading branch information
tkornuta-nvidia committed Jun 5, 2020
1 parent f9f1d45 commit 9d97337
Show file tree
Hide file tree
Showing 6 changed files with 129 additions and 18 deletions.
23 changes: 17 additions & 6 deletions examples/nlp/dialogue_state_tracking/rule_based_policy_multiwoz.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from nemo.collections.nlp.data.datasets.multiwoz_dataset.state import init_state
from nemo.collections.nlp.nm.non_trainables import (
RuleBasedDPMMultiWOZ,
SystemUtteranceHistoryUpdate,
TemplateNLGMultiWOZ,
TradeStateUpdateNM,
UserUtteranceEncoder,
Expand All @@ -62,7 +63,7 @@
]


def forward(dialog_pipeline, system_uttr, user_uttr, dial_history, belief_state):
def forward(dialog_pipeline, user_uttr, dial_history, belief_state):
"""
Forward pass of the "Complete Dialog Pipeline".
Expand All @@ -75,7 +76,6 @@ def forward(dialog_pipeline, system_uttr, user_uttr, dial_history, belief_state)
Args:
user_uttr (str): User utterance
system_uttr (str): Previous system utterance
dialog_history (str): Dialogue history contains all previous system and user utterances
belief_state (dict): dialogue state
Returns:
Expand All @@ -87,7 +87,7 @@ def forward(dialog_pipeline, system_uttr, user_uttr, dial_history, belief_state)
# 1. Forward pass throught Word-Level Dialog State Tracking modules (TRADE).
# 1.1. User utterance encoder.
dialog_ids, dialog_lens, dial_history = dialog_pipeline.modules[dialog_pipeline.steps[0]].forward(
user_uttr=user_uttr, sys_uttr=system_uttr, dialog_history=dial_history,
user_uttr=user_uttr, dialog_history=dial_history,
)
# 1.2. TRADE encoder.
outputs, hidden = dialog_pipeline.modules[dialog_pipeline.steps[1]].forward(
Expand All @@ -111,6 +111,11 @@ def forward(dialog_pipeline, system_uttr, user_uttr, dial_history, belief_state)
# 3. Forward pass throught Natural Language Generator module (Template-Based).
system_uttr = dialog_pipeline.modules[dialog_pipeline.steps[5]].forward(system_acts=system_acts)

# 4. Update dialog history with system utterance
dial_history = dialog_pipeline.modules[dialog_pipeline.steps[6]].forward(
sys_uttr=system_uttr, dialog_history=dial_history
)

# Return the updated states and dialog history.
return system_uttr, belief_state, dial_history

Expand Down Expand Up @@ -199,13 +204,16 @@ def forward(dialog_pipeline, system_uttr, user_uttr, dial_history, belief_state)
# NLG module.
template_nlg = TemplateNLGMultiWOZ()

# Updates dialog history with system utterance.
sys_utter_history_update = SystemUtteranceHistoryUpdate()

# Construct the "evaluation" (inference) neural graph by connecting the modules using nmTensors.
# Note: Using the same names for passed nmTensor as in the actual forward pass.
with NeuralGraph(operation_mode=OperationMode.evaluation) as dialog_pipeline:
# 1.1. User utterance encoder.
# Bind all the input ports of this module.
dialog_ids, dialog_lens, dial_history = user_utterance_encoder(
user_uttr=dialog_pipeline, sys_uttr=dialog_pipeline, dialog_history=dialog_pipeline,
user_uttr=dialog_pipeline, dialog_history=dialog_pipeline,
)
# Fire step 1: 1.2. TRADE encoder.
outputs, hidden = trade_encoder(inputs=dialog_ids, input_lens=dialog_lens)
Expand All @@ -229,6 +237,9 @@ def forward(dialog_pipeline, system_uttr, user_uttr, dial_history, belief_state)
# 3. Forward pass throught Natural Language Generator module (Template-Based).
system_uttr = template_nlg(system_acts=system_acts)

# 4. Update dialog history with system utterance
dial_history = sys_utter_history_update(sys_uttr=system_uttr, dialog_history=dial_history)

# Show the graph summary.
logging.info(dialog_pipeline.summary())

Expand All @@ -250,7 +261,7 @@ def forward(dialog_pipeline, system_uttr, user_uttr, dial_history, belief_state)
else:
# Pass the "user uterance" as inputs to the dialog pipeline.
system_uttr, belief_state, dial_history = forward(
dialog_pipeline, system_uttr, user_uttr, dial_history, belief_state
dialog_pipeline, user_uttr, dial_history, belief_state
)

elif args.mode == 'example':
Expand All @@ -262,5 +273,5 @@ def forward(dialog_pipeline, system_uttr, user_uttr, dial_history, belief_state)
for user_uttr in example:
logging.info("User utterance: %s", user_uttr)
system_uttr, belief_state, dial_history = forward(
dialog_pipeline, system_uttr, user_uttr, dial_history, belief_state
dialog_pipeline, user_uttr, dial_history, belief_state
)
3 changes: 3 additions & 0 deletions nemo/collections/nlp/nm/non_trainables/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
# =============================================================================

from nemo.collections.nlp.nm.non_trainables.dialogue_state_tracking.rule_based_dpm_multiwoz import RuleBasedDPMMultiWOZ
from nemo.collections.nlp.nm.non_trainables.dialogue_state_tracking.system_utterance_history_update import (
SystemUtteranceHistoryUpdate,
)
from nemo.collections.nlp.nm.non_trainables.dialogue_state_tracking.template_nlg_multiwoz import TemplateNLGMultiWOZ
from nemo.collections.nlp.nm.non_trainables.dialogue_state_tracking.trade_state_update_nm import TradeStateUpdateNM
from nemo.collections.nlp.nm.non_trainables.dialogue_state_tracking.user_utterance_encoder import UserUtteranceEncoder
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
# =============================================================================
# Copyright 2020 NVIDIA. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

'''
This file contains code artifacts adapted from the original implementation:
https://github.com/thu-coai/ConvLab-2/
'''
import torch

from nemo.backends.pytorch.nm import NonTrainableNM
from nemo.core.neural_types import *
from nemo.utils import logging
from nemo.utils.decorators import add_port_docs

__all__ = ['SystemUtteranceHistoryUpdate']


class SystemUtteranceHistoryUpdate(NonTrainableNM):
"""
Updates dialogue history with system utterance.
"""

@property
@add_port_docs()
def input_ports(self):
"""Returns definitions of module input ports.
user_uttr (str): user utterance
sys_uttr (str): system utterace
dialog_history (list): dialogue history, list of system and diaglogue utterances
"""
return {
'sys_uttr': NeuralType(
axes=[AxisType(kind=AxisKind.Batch, is_list=True)], elements_type=SystemUtterance()
),
'dialog_history': NeuralType(
axes=(AxisType(kind=AxisKind.Batch, is_list=True), AxisType(kind=AxisKind.Time, is_list=True),),
elements_type=AgentUtterance(),
),
}

@property
@add_port_docs()
def output_ports(self):
"""Returns definitions of module output ports.
dialog_history (list): dialogue history, being a list of user and system utterances.
"""
return {
'dialog_history': NeuralType(
axes=(AxisType(kind=AxisKind.Batch, is_list=True), AxisType(kind=AxisKind.Time, is_list=True),),
elements_type=AgentUtterance(),
),
}

def __init__(self):
"""
Initializes the object
Args:
data_desc (obj): data descriptor for MultiWOZ dataset, contains information about domains, slots,
and associated vocabulary
"""
super().__init__()

def forward(self, sys_uttr, dialog_history):
"""
Returns updated dialog history.
Args:
sys_uttr (str): system utterace
dialog_history (list): dialogue history, list of user and system diaglogue utterances
Returns:
dialog_history (list): updated dialogue history, list of user and system diaglogue utterances
"""
dialog_history.append(["sys", sys_uttr])
logging.debug("Dialogue history: %s", dialog_history)

return dialog_history
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,9 @@ def output_ports(self):
system_uttr (str): generated system's response
"""
return {
'sys_uttr': NeuralType(axes=[AxisType(kind=AxisKind.Batch, is_list=True)], elements_type=Utterance()),
'sys_uttr': NeuralType(
axes=[AxisType(kind=AxisKind.Batch, is_list=True)], elements_type=SystemUtterance()
),
}

def forward(self, system_acts):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@

class UserUtteranceEncoder(NonTrainableNM):
"""
Encodes dialogue history (system and user utterances) into a Multiwoz dataset format
Updates dialogue history with user utterance and encodes the history (system and user utterances) into
a flat list of tokens (per sample).
Args:
data_desc (obj): data descriptor for MultiWOZ dataset, contains information about domains, slots,
and associated vocabulary
Expand All @@ -41,12 +42,10 @@ class UserUtteranceEncoder(NonTrainableNM):
def input_ports(self):
"""Returns definitions of module input ports.
user_uttr (str): user utterance
sys_uttr (str): system utterace
dialog_history (list): dialogue history, list of system and diaglogue utterances
"""
return {
'user_uttr': NeuralType(axes=[AxisType(kind=AxisKind.Batch, is_list=True)], elements_type=Utterance()),
'sys_uttr': NeuralType(axes=[AxisType(kind=AxisKind.Batch, is_list=True)], elements_type=Utterance()),
'user_uttr': NeuralType(axes=[AxisType(kind=AxisKind.Batch, is_list=True)], elements_type=UserUtterance()),
'dialog_history': NeuralType(
axes=(AxisType(kind=AxisKind.Batch, is_list=True), AxisType(kind=AxisKind.Time, is_list=True),),
elements_type=AgentUtterance(),
Expand Down Expand Up @@ -81,20 +80,18 @@ def __init__(self, data_desc):
super().__init__()
self.data_desc = data_desc

def forward(self, user_uttr, sys_uttr, dialog_history):
def forward(self, user_uttr, dialog_history):
"""
Returns dialogue utterances in the format accepted by the TRADE Dialogue state tracking model
Args:
dialog_history (list): dialogue history, list of system and diaglogue utterances
user_uttr (str): user utterance
sys_uttr (str): system utterace
Returns:
dialog_ids (int): token ids for the whole dialogue history
dialog_lens (int): length of the whole tokenized dialogue history
dialog_history (list): updated dialogue history, list of system and diaglogue utterances
"""
# TODO: why we update sys utterance, whereas we have only user utterance at that point?
dialog_history.append(["sys", sys_uttr])
dialog_history.append(["user", user_uttr])
logging.debug("Dialogue history: %s", dialog_history)

Expand Down
18 changes: 14 additions & 4 deletions nemo/core/neural_types/elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@
'StringLabel',
'StringType',
'Utterance',
'UserUtterance',
'SystemUtterance',
'AgentUtterance',
'TokenIndex',
'Length',
Expand Down Expand Up @@ -258,14 +260,22 @@ class Utterance(StringType):
"""Element type representing an utterance (e.g. "Is there a train from Ely to Cambridge on Tuesday ?")."""


class UserUtterance(Utterance):
"""Element type representing a utterance expresesd by the user."""


class SystemUtterance(Utterance):
"""Element type representing an utterance produced by the system."""


class AgentUtterance(ElementType):
"""Element type representing utterance returned by an agent (user or system) participating in a dialog."""

def __str__(self):
return "Utterance returned by an agent (user or system) participating in a dialog."
# def __str__(self):
# return "Utterance returned by an agent (user or system) participating in a dialog."

def fields(self):
return ("agent", "utterance")
# def fields(self):
# return ("agent", "utterance")


class IntType(ElementType):
Expand Down

0 comments on commit 9d97337

Please sign in to comment.