Skip to content

Commit

Permalink
fixed TRADE training
Browse files Browse the repository at this point in the history
Signed-off-by: Evelina Bakhturina <ebakhturina@nvidia.com>
  • Loading branch information
ekmb committed Jun 4, 2020
1 parent db3d410 commit f9f1d45
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,12 @@

import nemo.core as nemo_core
from nemo import logging
from nemo.backends.pytorch.common import EncoderRNN
from nemo.backends.pytorch.common.losses import CrossEntropyLossNM, LossAggregatorNM
from nemo.collections.nlp.callbacks.state_tracking_trade_callback import eval_epochs_done_callback, eval_iter_callback
from nemo.collections.nlp.data.datasets.multiwoz_dataset import MultiWOZDataDesc
from nemo.collections.nlp.nm.data_layers import MultiWOZDataLayer
from nemo.collections.nlp.nm.losses import MaskedLogLoss
from nemo.collections.nlp.nm.trainables import TRADEGenerator
from nemo.collections.nlp.nm.trainables import EncoderRNN, TRADEGenerator
from nemo.utils.lr_policies import get_lr_policy

parser = argparse.ArgumentParser(description='Dialogue state tracking with TRADE model on MultiWOZ dataset')
Expand Down Expand Up @@ -155,8 +154,8 @@ def create_pipeline(num_samples, batch_size, num_gpus, input_dropout, data_prefi
point_outputs, gate_outputs = decoder(
encoder_hidden=hidden,
encoder_outputs=outputs,
input_lens=input_data.src_lens,
src_ids=input_data.src_ids,
dialog_lens=input_data.src_lens,
dialog_ids=input_data.src_ids,
targets=input_data.tgt_ids,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
import nemo
from nemo.collections.nlp.data.datasets.multiwoz_dataset import MultiWOZDataset
from nemo.collections.nlp.nm.data_layers.text_datalayer import TextDataLayer
from nemo.core.neural_types import ChannelType, LabelsType, LengthsType, NeuralType
from nemo.core.neural_types import LabelsType, Length, NeuralType, TokenIndex
from nemo.utils.decorators import add_port_docs

__all__ = ['MultiWOZDataLayer']
Expand Down Expand Up @@ -83,10 +83,10 @@ def output_ports(self):
"""
return {
"src_ids": NeuralType(('B', 'T'), ChannelType()),
"src_lens": NeuralType(tuple('B'), LengthsType()),
"src_ids": NeuralType(('B', 'T'), TokenIndex()),
"src_lens": NeuralType(tuple('B'), Length()),
"tgt_ids": NeuralType(('B', 'D', 'T'), LabelsType()),
"tgt_lens": NeuralType(('B', 'D'), LengthsType()),
"tgt_lens": NeuralType(('B', 'D'), Length()),
"gating_labels": NeuralType(('B', 'D'), LabelsType()),
"turn_domain": NeuralType(),
}
Expand Down
4 changes: 2 additions & 2 deletions nemo/collections/nlp/nm/losses/masked_xentropy_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
import torch

from nemo.backends.pytorch.nm import LossNM
from nemo.core.neural_types import LabelsType, LengthsType, LogitsType, LossType, NeuralType
from nemo.core.neural_types import LabelsType, Length, LogitsType, LossType, NeuralType
from nemo.utils.decorators import add_port_docs

__all__ = ['MaskedLogLoss']
Expand Down Expand Up @@ -72,7 +72,7 @@ def input_ports(self):
return {
"logits": NeuralType(('B', 'T', 'D', 'D'), LogitsType()),
"labels": NeuralType(('B', 'D', 'T'), LabelsType()),
"length_mask": NeuralType(('B', 'D'), LengthsType()),
"length_mask": NeuralType(('B', 'D'), Length()),
}

@property
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@ def input_ports(self):
"""Returns definitions of module input ports.
"""
return {
'encoder_hidden': NeuralType(('B', 'T', 'C'), ChannelType()),
'encoder_outputs': NeuralType(('B', 'T', 'C'), ChannelType()),
'encoder_hidden': NeuralType(('B', 'T', 'D'), ChannelType()),
'encoder_outputs': NeuralType(('B', 'T', 'D'), ChannelType()),
'dialog_ids': NeuralType(('B', 'T'), elements_type=TokenIndex()),
'dialog_lens': NeuralType(tuple('B'), elements_type=Length()),
'targets': NeuralType(('B', 'D', 'T'), LabelsType(), optional=True),
Expand Down

0 comments on commit f9f1d45

Please sign in to comment.