diff --git a/examples/nlp/dialogue_state_tracking/dialogue_state_tracking_trade.py b/examples/nlp/dialogue_state_tracking/dialogue_state_tracking_trade.py index 2059a843de29..88e7c24c1b6a 100644 --- a/examples/nlp/dialogue_state_tracking/dialogue_state_tracking_trade.py +++ b/examples/nlp/dialogue_state_tracking/dialogue_state_tracking_trade.py @@ -30,13 +30,12 @@ import nemo.core as nemo_core from nemo import logging -from nemo.backends.pytorch.common import EncoderRNN from nemo.backends.pytorch.common.losses import CrossEntropyLossNM, LossAggregatorNM from nemo.collections.nlp.callbacks.state_tracking_trade_callback import eval_epochs_done_callback, eval_iter_callback from nemo.collections.nlp.data.datasets.multiwoz_dataset import MultiWOZDataDesc from nemo.collections.nlp.nm.data_layers import MultiWOZDataLayer from nemo.collections.nlp.nm.losses import MaskedLogLoss -from nemo.collections.nlp.nm.trainables import TRADEGenerator +from nemo.collections.nlp.nm.trainables import EncoderRNN, TRADEGenerator from nemo.utils.lr_policies import get_lr_policy parser = argparse.ArgumentParser(description='Dialogue state tracking with TRADE model on MultiWOZ dataset') @@ -155,8 +154,8 @@ def create_pipeline(num_samples, batch_size, num_gpus, input_dropout, data_prefi point_outputs, gate_outputs = decoder( encoder_hidden=hidden, encoder_outputs=outputs, - input_lens=input_data.src_lens, - src_ids=input_data.src_ids, + dialog_lens=input_data.src_lens, + dialog_ids=input_data.src_ids, targets=input_data.tgt_ids, ) diff --git a/nemo/collections/nlp/nm/data_layers/state_tracking_trade_datalayer.py b/nemo/collections/nlp/nm/data_layers/state_tracking_trade_datalayer.py index ab0c5edb161e..956f56f0b52b 100644 --- a/nemo/collections/nlp/nm/data_layers/state_tracking_trade_datalayer.py +++ b/nemo/collections/nlp/nm/data_layers/state_tracking_trade_datalayer.py @@ -43,7 +43,7 @@ import nemo from nemo.collections.nlp.data.datasets.multiwoz_dataset import MultiWOZDataset from nemo.collections.nlp.nm.data_layers.text_datalayer import TextDataLayer -from nemo.core.neural_types import ChannelType, LabelsType, LengthsType, NeuralType +from nemo.core.neural_types import LabelsType, Length, NeuralType, TokenIndex from nemo.utils.decorators import add_port_docs __all__ = ['MultiWOZDataLayer'] @@ -83,10 +83,10 @@ def output_ports(self): """ return { - "src_ids": NeuralType(('B', 'T'), ChannelType()), - "src_lens": NeuralType(tuple('B'), LengthsType()), + "src_ids": NeuralType(('B', 'T'), TokenIndex()), + "src_lens": NeuralType(tuple('B'), Length()), "tgt_ids": NeuralType(('B', 'D', 'T'), LabelsType()), - "tgt_lens": NeuralType(('B', 'D'), LengthsType()), + "tgt_lens": NeuralType(('B', 'D'), Length()), "gating_labels": NeuralType(('B', 'D'), LabelsType()), "turn_domain": NeuralType(), } diff --git a/nemo/collections/nlp/nm/losses/masked_xentropy_loss.py b/nemo/collections/nlp/nm/losses/masked_xentropy_loss.py index 6d876ea752d7..0160564d5731 100644 --- a/nemo/collections/nlp/nm/losses/masked_xentropy_loss.py +++ b/nemo/collections/nlp/nm/losses/masked_xentropy_loss.py @@ -39,7 +39,7 @@ import torch from nemo.backends.pytorch.nm import LossNM -from nemo.core.neural_types import LabelsType, LengthsType, LogitsType, LossType, NeuralType +from nemo.core.neural_types import LabelsType, Length, LogitsType, LossType, NeuralType from nemo.utils.decorators import add_port_docs __all__ = ['MaskedLogLoss'] @@ -72,7 +72,7 @@ def input_ports(self): return { "logits": NeuralType(('B', 'T', 'D', 'D'), LogitsType()), "labels": NeuralType(('B', 'D', 'T'), LabelsType()), - "length_mask": NeuralType(('B', 'D'), LengthsType()), + "length_mask": NeuralType(('B', 'D'), Length()), } @property diff --git a/nemo/collections/nlp/nm/trainables/dialogue_state_tracking/trade_generator_nm.py b/nemo/collections/nlp/nm/trainables/dialogue_state_tracking/trade_generator_nm.py index 9d7df7928944..08c214ce5018 100644 --- a/nemo/collections/nlp/nm/trainables/dialogue_state_tracking/trade_generator_nm.py +++ b/nemo/collections/nlp/nm/trainables/dialogue_state_tracking/trade_generator_nm.py @@ -70,8 +70,8 @@ def input_ports(self): """Returns definitions of module input ports. """ return { - 'encoder_hidden': NeuralType(('B', 'T', 'C'), ChannelType()), - 'encoder_outputs': NeuralType(('B', 'T', 'C'), ChannelType()), + 'encoder_hidden': NeuralType(('B', 'T', 'D'), ChannelType()), + 'encoder_outputs': NeuralType(('B', 'T', 'D'), ChannelType()), 'dialog_ids': NeuralType(('B', 'T'), elements_type=TokenIndex()), 'dialog_lens': NeuralType(tuple('B'), elements_type=Length()), 'targets': NeuralType(('B', 'D', 'T'), LabelsType(), optional=True),