diff --git a/examples/nlp/dialogue_state_tracking/rule_based_policy_multiwoz.py b/examples/nlp/dialogue_state_tracking/rule_based_policy_multiwoz.py index 69c4db4e9211..018301364a89 100644 --- a/examples/nlp/dialogue_state_tracking/rule_based_policy_multiwoz.py +++ b/examples/nlp/dialogue_state_tracking/rule_based_policy_multiwoz.py @@ -128,17 +128,18 @@ def init_session(): return '', '', default_state() -def get_system_responce(user_uttr, system_uttr, dialog_history, state): +def get_system_response(user_uttr, system_uttr, dialog_history, state): """ - Returns system reply by passing user utterance through TRADE Dialogue State Tracker, then the output of the TRADE model to the - Rule-base Dialogue Policy Magager and the output of the Policy Manager to the Rule-based Natural language generation module + Returns system reply by passing system and user utterances (dialogue history) through the TRADE Dialogue State Tracker, + then the output of the TRADE model goes to the Rule-base Dialogue Policy Magager + and the output of the Policy Manager goes to the Rule-based Natural language generation module Args: - user_uttr(str): User utterance - system_uttr(str): Previous system utterance - dialog_history(str): Diaglogue history contains all previous system and user utterances + user_uttr (str): User utterance + system_uttr (str): Previous system utterance + dialog_history (str): Diaglogue history contains all previous system and user utterances state (dict): dialogue state Returns: - system_utterance(str): system response + system_utterance (str): system response state (dict): updated dialogue state """ src_ids, src_lens = utterance_encoder.forward(state=state, user_uttr=user_uttr, sys_uttr=system_uttr) @@ -177,7 +178,7 @@ def get_system_responce(user_uttr, system_uttr, dialog_history, state): system_uttr, dialog_history, state = init_session() logging.info("============ Starting a new dialogue ============") else: - get_system_responce(user_uttr, system_uttr, dialog_history, state) + get_system_response(user_uttr, system_uttr, dialog_history, state) elif args.mode == 'example': for example in examples: @@ -185,4 +186,4 @@ def get_system_responce(user_uttr, system_uttr, dialog_history, state): system_uttr, dialog_history, state = init_session() for user_uttr in example: logging.info("User utterance: %s", user_uttr) - system_uttr, state = get_system_responce(user_uttr, system_uttr, dialog_history, state) + system_uttr, state = get_system_response(user_uttr, system_uttr, dialog_history, state) diff --git a/nemo/collections/nlp/nm/non_trainables/dialogue_state_tracking/rule_based_multiwoz_bot.py b/nemo/collections/nlp/nm/non_trainables/dialogue_state_tracking/rule_based_multiwoz_bot.py index d48f91d0f241..44ad151cac2f 100755 --- a/nemo/collections/nlp/nm/non_trainables/dialogue_state_tracking/rule_based_multiwoz_bot.py +++ b/nemo/collections/nlp/nm/non_trainables/dialogue_state_tracking/rule_based_multiwoz_bot.py @@ -203,8 +203,6 @@ def forward(self, state): for user_act_ in user_acts: del DA[user_act_] - # print("Sys action: ", DA) - if DA == {}: DA = {'general-greet': [['none', 'none']]} tuples = [] @@ -284,12 +282,6 @@ def _update_DA(self, user_act, user_action, state, DA): kb_result = self.db.query(domain.lower(), constraints) self.kb_result[domain] = deepcopy(kb_result) - # print("\tConstraint: " + "{}".format(constraints)) - # print("\tCandidate Count: " + "{}".format(len(kb_result))) - # if len(kb_result) > 0: - # print("Candidate: " + "{}".format(kb_result[0])) - - # print(state['user_action']) # Respond to user's request if intent_type == 'Request': if self.recommend_flag > 1: @@ -310,7 +302,6 @@ def _update_DA(self, user_act, user_action, state, DA): else: # There's no result matching user's constraint - # if len(state['kb_results_dict']) == 0: if len(kb_result) == 0: if (domain + "-NoOffer") not in DA: DA[domain + "-NoOffer"] = [] @@ -468,8 +459,6 @@ def _update_train(self, user_act, user_action, state, DA): kb_result = self.db.query('train', constraints) self.kb_result['Train'] = deepcopy(kb_result) - # print(constraints) - # print(len(kb_result)) if user_act == 'Train-Request': del DA['Train-Request'] if 'Train-Inform' not in DA: diff --git a/nemo/collections/nlp/nm/non_trainables/dialogue_state_tracking/trade_state_update_nm.py b/nemo/collections/nlp/nm/non_trainables/dialogue_state_tracking/trade_state_update_nm.py index 2447b05f7662..c7904d785b61 100755 --- a/nemo/collections/nlp/nm/non_trainables/dialogue_state_tracking/trade_state_update_nm.py +++ b/nemo/collections/nlp/nm/non_trainables/dialogue_state_tracking/trade_state_update_nm.py @@ -16,7 +16,8 @@ ''' This file contains code artifacts adapted from the original implementation: -https://github.com/thu-coai/ConvLab-2/blob/master/convlab2/policy/rule/multiwoz/rule_based_multiwoz_bot.py +https://github.com/thu-coai/ConvLab-2/blob/master/convlab2/dst/trade/multiwoz/trade.py +https://github.com/thu-coai/ConvLab-2 ''' import copy import re @@ -33,76 +34,6 @@ __all__ = ['TradeStateUpdateNM'] -# class UtteranceEncoderNM(NonTrainableNM): -# """ -# Encodes dialogue history (system and user utterances) into a Multiwoz dataset format -# Args: -# data_desc (obj): data descriptor for MultiWOZ dataset, contains information about domains, slots, -# and associated vocabulary -# """ - -# @property -# @add_port_docs() -# def input_ports(self): -# """Returns definitions of module input ports. -# state (dict): dialogue state dictionary - see nemo.collections.nlp.data.datasets.multiwoz_dataset.state -# for the format -# user_uttr (str): user utterance -# sys_uttr (str): system utterace -# """ -# return { -# "state": NeuralType(axes=tuple('ANY'), element_type=VoidType()), -# "user_uttr": NeuralType(axes=tuple('ANY'), element_type=VoidType()), -# "sys_uttr": NeuralType(axes=tuple('ANY'), element_type=VoidType()) -# } - - -# @property -# @add_port_docs() -# def output_ports(self): -# """Returns definitions of module output ports. -# src_ids (int): token ids for dialogue history -# src_lens (int): length of the tokenized dialogue history -# """ -# return { -# 'src_ids': NeuralType(('B', 'T'), element_type=ChannelType()), -# 'src_lens': NeuralType(tuple('B'), elemenet_type=LengthsType()), -# } - -# def __init__(self, data_desc): -# """ -# Init -# Args: -# data_desc (obj): data descriptor for MultiWOZ dataset, contains information about domains, slots, -# and associated vocabulary -# """ -# super().__init__() -# self.data_desc = data_desc - -# def forward(self, state, user_uttr, sys_uttr): -# """ -# Returns dialogue utterances in the format accepted by the TRADE Dialogue state tracking model -# Args: -# state (dict): state dictionary - see nemo.collections.nlp.data.datasets.multiwoz_dataset.state -# for the format -# user_uttr (str): user utterance -# sys_uttr (str): system utterace -# Returns: -# src_ids (int): token ids for dialogue history -# src_lens (int): length of the tokenized dialogue history -# """ -# state["history"].append(["sys", sys_uttr]) -# state["history"].append(["user", user_uttr]) -# state["user_action"] = user_uttr -# logging.debug("Dialogue state: %s", state) - -# context = ' ; '.join([item[1].strip().lower() for item in state['history']]).strip() + ' ;' -# context_ids = self.data_desc.vocab.tokens2ids(context.split()) -# src_ids = torch.tensor(context_ids).unsqueeze(0).to(self._device) -# src_lens = torch.tensor(len(context_ids)).unsqueeze(0).to(self._device) -# return src_ids, src_lens - - class TradeStateUpdateNM(NonTrainableNM): """ Takes the predictions of the TRADE Dialogue state tracking model, @@ -288,9 +219,7 @@ def normalize_value(self, value_set, domain, slot, value): if slot not in value_set[domain]: logging.warning('slot {} no in domain {}'.format(slot, domain)) return value - # raise Exception( - # 'slot <{}> not found in db_values[{}]'.format( - # slot, domain)) + value_list = value_set[domain][slot] # exact match or containing match v = self._match_or_contain(value, value_list) diff --git a/nemo/collections/nlp/nm/non_trainables/dialogue_state_tracking/utterance_encoder_nm.py b/nemo/collections/nlp/nm/non_trainables/dialogue_state_tracking/utterance_encoder_nm.py index b7d1ae3ef1fb..ba8d56bf29f6 100755 --- a/nemo/collections/nlp/nm/non_trainables/dialogue_state_tracking/utterance_encoder_nm.py +++ b/nemo/collections/nlp/nm/non_trainables/dialogue_state_tracking/utterance_encoder_nm.py @@ -16,7 +16,7 @@ ''' This file contains code artifacts adapted from the original implementation: -https://github.com/thu-coai/ConvLab-2/blob/master/convlab2/policy/rule/multiwoz/rule_based_multiwoz_bot.py +https://github.com/thu-coai/ConvLab-2/ ''' import copy import re @@ -100,352 +100,3 @@ def forward(self, state, user_uttr, sys_uttr): src_ids = torch.tensor(context_ids).unsqueeze(0).to(self._device) src_lens = torch.tensor(len(context_ids)).unsqueeze(0).to(self._device) return src_ids, src_lens - - -# class TradeOutputNM(NonTrainableNM): -# """ -# Takes the predictions of the TRADE Dialogue state tracking model, -# generates human-readable model output and updates the dialogue -# state with the TRADE predcitions -# """ - -# @property -# @add_port_docs() -# def input_ports(self): -# """Returns definitions of module input ports. -# """ -# return { -# 'point_outputs_pred': NeuralType(('B', 'T', 'D', 'D'), LogitsType()), -# 'gating_preds': NeuralType(('B', 'D', 'D'), LogitsType()) -# } - -# @property -# @add_port_docs() -# def output_ports(self): -# """Returns definitions of module output ports. - -# """ -# return { -# 'state': NeuralType( -# axes=(AxisType(kind=AxisKind.Any, is_list=True), -# ), elements_type=VoidType(), -# )} - -# def __init__(self, data_desc): -# """ -# Init -# Args: -# data_desc (obj): data descriptor for MultiWOZ dataset, contains information about domains, slots, -# and associated vocabulary -# """ -# super().__init__() -# self.data_desc = data_desc - - -# def forward(self, state, gating_preds, point_outputs_pred): -# """ -# Processes the TRADE model output and updates the dialogue state with the model's predictions -# Args: -# state (dict): -# gating_preds (float): TRADE model gating predictions -# point_outputs_pred (float): TRADE model pointers predictions -# """ -# prev_state = state -# gate_outputs_max, point_outputs_max = self.get_trade_prediction(gating_preds, point_outputs_pred) -# trade_output = self.get_human_readable_output(gate_outputs_max, point_outputs_max)[0] -# logging.debug('TRADE output: %s', trade_output) - -# new_belief_state = self.reformat_belief_state( -# trade_output, copy.deepcopy(prev_state['belief_state']), self.data_desc.ontology_value_dict -# ) -# state['belief_state'] = new_belief_state - -# # update request state based on the latest user utterance -# new_request_state = copy.deepcopy(state['request_state']) -# # extract current user output -# user_uttr = state['user_action'] -# user_request_slot = self.detect_requestable_slots(user_uttr.lower(), self.data_desc.det_dict) -# for domain in user_request_slot: -# for key in user_request_slot[domain]: -# if domain not in new_request_state: -# new_request_state[domain] = {} -# if key not in new_request_state[domain]: -# new_request_state[domain][key] = user_request_slot[domain][key] -# state['request_state'] = new_request_state -# return state - -# def get_trade_prediction(self, gating_preds, point_outputs_pred): -# """ -# Takes argmax of the model's predictions -# Args: -# gating_preds (float): TRADE model gating predictions -# point_outputs_pred (float): TRADE model output, contains predicted pointers -# Returns: -# gate_outputs_max_list (array): list of the gating predicions -# point_outputs_max_list (list of arrays): each array contains the pointers predictions -# """ -# p_max = torch.argmax(point_outputs_pred, dim=-1) -# point_outputs_max = [tensor2numpy(p_max)] -# g_max = torch.argmax(gating_preds, axis=-1) -# gate_outputs_max = tensor2numpy(g_max) -# return gate_outputs_max, point_outputs_max - -# def get_human_readable_output(self, gating_preds, point_outputs_pred): -# """ -# Returns trade output in the human readable format -# Args: -# gating_preds (array): an array of gating predictions, TRADE model output -# point_outputs_pred (list of arrays): TRADE model output, contains predicted pointers -# Returns: -# output (list of strings): TRADE model output, each values represents domain-slot_name-slot_value, -# for example, ['hotel-pricerange-cheap', 'hotel-type-hotel'] -# """ -# slots = self.data_desc.slots -# bi = 0 -# predict_belief_bsz_ptr = [] -# inverse_unpoint_slot = dict([(v, k) for k, v in self.data_desc.gating_dict.items()]) - -# for si, sg in enumerate(gating_preds[bi]): -# if sg == self.data_desc.gating_dict["none"]: -# continue -# elif sg == self.data_desc.gating_dict["ptr"]: -# pred = point_outputs_pred[0][0][si] - -# pred = [self.data_desc.vocab.idx2word[x] for x in pred] - -# st = [] -# for e in pred: -# if e == 'EOS': -# break -# else: -# st.append(e) -# st = " ".join(st) -# if st == "none": -# continue -# else: -# predict_belief_bsz_ptr.append(slots[si] + "-" + str(st)) -# else: -# predict_belief_bsz_ptr.append(slots[si] + "-" + inverse_unpoint_slot[sg]) -# # predict_belief_bsz_ptr ['hotel-pricerange-cheap', 'hotel-type-hotel'] -# output = [predict_belief_bsz_ptr] -# return output - -# def reformat_belief_state(self, raw_state, bs, value_dict): -# ''' -# Reformat TRADE model raw state into the default_state format -# Args: -# raw_state(list of strings): raw TRADE model output/state, each values represents domain-slot_name-slot_value, -# for example, ['hotel-pricerange-cheap', 'hotel-type-hotel'] -# bs (dict): belief state - see nemo.collections.nlp.data.datasets.multiwoz.default_state for the format -# value_dict (dict): a dictionary of all slot values for MultiWOZ dataset -# Returns: -# bs (dict): reformatted belief state -# ''' -# for item in raw_state: -# item = item.lower() -# slist = item.split('-', 2) -# domain = slist[0].strip() -# slot = slist[1].strip() -# value = slist[2].strip() -# if domain not in bs: -# raise Exception('Error: domain <{}> not in belief state'.format(domain)) -# dbs = bs[domain] -# assert 'semi' in dbs -# assert 'book' in dbs -# slot = REF_SYS_DA[domain.capitalize()].get(slot, slot) -# # reformat some slots -# if slot == 'arriveby': -# slot = 'arriveBy' -# elif slot == 'leaveat': -# slot = 'leaveAt' -# if slot in dbs['semi']: -# dbs['semi'][slot] = self.normalize_value(value_dict, domain, slot, value) -# elif slot in dbs['book']: -# dbs['book'][slot] = value -# elif slot.lower() in dbs['book']: -# dbs['book'][slot.lower()] = value -# else: -# logging.warning( -# 'unknown slot name <{}> with value <{}> of domain <{}>\nitem: {}\n\n'.format(slot, value, domain, item) -# ) -# return bs - -# def normalize_value(self, value_set, domain, slot, value): -# """Normalized the value produced by NLU module to map it to the ontology value space. -# Args: -# value_set (dict): The value set of task ontology. -# domain (str): The domain of the slot-value pairs. -# slot (str): The slot of the value. -# value (str): The raw value detected by NLU module. -# Returns: -# value (str): The normalized value, which fits with the domain ontology. -# """ -# slot = slot.lower() -# value = value.lower() -# value = ' '.join(value.split()) -# try: -# assert domain in value_set -# except: -# raise Exception('domain <{}> not found in value set'.format(domain)) -# if slot not in value_set[domain]: -# logging.warning('slot {} no in domain {}'.format(slot, domain)) -# return value -# # raise Exception( -# # 'slot <{}> not found in db_values[{}]'.format( -# # slot, domain)) -# value_list = value_set[domain][slot] -# # exact match or containing match -# v = self._match_or_contain(value, value_list) -# if v is not None: -# return v -# # some transfomations -# cand_values = self._transform_value(value) -# for cv in cand_values: -# v = self._match_or_contain(cv, value_list) -# if v is not None: -# logging.warning('slot value found via _match_or_contain') -# return v -# # special value matching -# v = self.special_match(domain, slot, value) -# if v is not None: -# logging.warning('slot value found via special_match') -# return v -# logging.warning('Failed: domain {} slot {} value {}, raw value returned.'.format(domain, slot, value)) -# return value - -# def _transform_value(self, value): -# """makes clean up value transformations""" -# cand_list = [] -# # a 's -> a's -# if " 's" in value: -# cand_list.append(value.replace(" 's", "'s")) -# # a - b -> a-b -# if " - " in value: -# cand_list.append(value.replace(" - ", "-")) -# # center <-> centre -# if value == 'center': -# cand_list.append('centre') -# elif value == 'centre': -# cand_list.append('center') -# # the + value -# if not value.startswith('the '): -# cand_list.append('the ' + value) -# return cand_list - -# def minDistance(self, word1, word2): -# """The minimum edit distance between word 1 and 2.""" -# if not word1: -# return len(word2 or '') or 0 -# if not word2: -# return len(word1 or '') or 0 -# size1 = len(word1) -# size2 = len(word2) -# tmp = list(range(size2 + 1)) -# value = None -# for i in range(size1): -# tmp[0] = i + 1 -# last = i -# for j in range(size2): -# if word1[i] == word2[j]: -# value = last -# else: -# value = 1 + min(last, tmp[j], tmp[j + 1]) -# last = tmp[j + 1] -# tmp[j + 1] = value -# return value - -# def _match_or_contain(self, value, value_list): -# """ -# Matches value by exact match or containing -# Args: -# value (str): slot value -# value_list (list of str): list of possible slot_values -# Returns: -# matched value -# """ -# if value in value_list: -# return value -# for v in value_list: -# if v in value or value in v: -# return v -# # fuzzy match, when len(value) is large and distance(v1, v2) is small -# for v in value_list: -# d = self.minDistance(value, v) -# if (d <= 2 and len(value) >= 10) or (d <= 3 and len(value) >= 15): -# return v -# return None - - -# def special_match(self, domain, slot, value): -# """special slot fuzzy matching""" -# matched_result = None -# if slot == 'arriveby' or slot == 'leaveat': -# matched_result = self._match_time(value) -# elif slot == 'price' or slot == 'entrance fee': -# matched_result = self._match_pound_price(value) -# elif slot == 'trainid': -# matched_result = self._match_trainid(value) -# elif slot == 'duration': -# matched_result = self._match_duration(value) -# return matched_result - - -# def _match_time(self, value): -# """Returns the time (leaveby, arriveat) in value, None if no time in value.""" -# mat = re.search(r"(\d{1,2}:\d{1,2})", value) -# if mat is not None and len(mat.groups()) > 0: -# return mat.groups()[0] -# return None - - -# def _match_trainid(self, value): -# """Returns the trainID in value, None if no trainID.""" -# mat = re.search(r"TR(\d{4})", value) -# if mat is not None and len(mat.groups()) > 0: -# return mat.groups()[0] -# return None - - -# def _match_pound_price(self, value): -# """Return the price with pounds in value, None if no trainID.""" -# mat = re.search(r"(\d{1,2},\d{1,2} pounds)", value) -# if mat is not None and len(mat.groups()) > 0: -# return mat.groups()[0] -# mat = re.search(r"(\d{1,2} pounds)", value) -# if mat is not None and len(mat.groups()) > 0: -# return mat.groups()[0] -# if "1 pound" in value.lower(): -# return '1 pound' -# if 'free' in value: -# return 'free' -# return None - - -# def _match_duration(self, value): -# """Return the durations (by minute) in value, None if no trainID.""" -# mat = re.search(r"(\d{1,2} minutes)", value) -# if mat is not None and len(mat.groups()) > 0: -# return mat.groups()[0] -# return None - - -# def detect_requestable_slots(self, observation, det_dic): -# """ -# Finds slot values in the observation (user utterance) and adds the to the requested slots list -# Args: -# observation (str): user utterance -# det_dic (dict): a dictionary of slot_name + (slot_name_domain) value pairs from user dialogue acts -# Returns: -# result (dict): of the requested slots in a format: {domain: {slot_name}: 0} -# """ -# result = {} -# observation = observation.lower() -# _observation = ' {} '.format(observation) -# for value in det_dic.keys(): -# _value = ' {} '.format(value.strip()) -# if _value in _observation: -# key, domain = det_dic[value].split('-') -# if domain not in result: -# result[domain] = {} -# result[domain][key] = 0 -# return result