Skip to content

Commit

Permalink
fix unittests
Browse files Browse the repository at this point in the history
Signed-off-by: Oleksii Kuchaiev <okuchaiev@nvidia.com>
  • Loading branch information
okuchaiev committed Feb 11, 2020
1 parent 781ebc5 commit 51120d9
Show file tree
Hide file tree
Showing 4 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion nemo/backends/pytorch/common/rnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def input_ports(self):
"""
return {
# 'targets': NeuralType({0: AxisType(BatchTag), 1: AxisType(TimeTag)}),
'targets': NeuralType(ChannelType(), ('B', 'T')),
'targets': NeuralType(LabelsType(), ('B', 'T')),
# 'encoder_outputs': NeuralType(
# {0: AxisType(BatchTag), 1: AxisType(TimeTag), 2: AxisType(ChannelTag),}, optional=True,
# ),
Expand Down
2 changes: 1 addition & 1 deletion nemo/collections/asr/data_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def output_ports(self):
# 'transcript_length': NeuralType({0: AxisType(BatchTag)}),
'audio_signal': NeuralType(AudioSignal(freq=self._sample_rate), ('B', 'T')),
'a_sig_length': NeuralType(LengthsType(), tuple('B')),
'transcripts': NeuralType(ChannelType(), ('B', 'T')),
'transcripts': NeuralType(LabelsType(), ('B', 'T')),
'transcript_length': NeuralType(LengthsType(), tuple('B')),
}

Expand Down
2 changes: 1 addition & 1 deletion nemo/collections/asr/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def input_ports(self):
# "input_length": NeuralType({0: AxisType(BatchTag)}),
# "target_length": NeuralType({0: AxisType(BatchTag)}),
"log_probs": NeuralType(LogprobsType(), ('B', 'T', 'D')),
"targets": NeuralType(PredictionsType(), ('B', 'T')),
"targets": NeuralType(LabelsType(), ('B', 'T')),
"input_length": NeuralType(LengthsType(), tuple('B')),
"target_length": NeuralType(LengthsType(), tuple('B')),
}
Expand Down
2 changes: 1 addition & 1 deletion tests/asr/test_zeroDS.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def test_asr_with_zero_ds(self):
(AxisType(AxisKind.Batch), AxisType(AxisKind.Dimension, 64), AxisType(AxisKind.Time, 64)),
),
"processed_length": NeuralType(LengthsType(), tuple('B')),
"transcript": NeuralType(ChannelType(), (AxisType(AxisKind.Batch), AxisType(AxisKind.Time, 64))),
"transcript": NeuralType(LabelsType(), (AxisType(AxisKind.Batch), AxisType(AxisKind.Time, 64))),
"transcript_length": NeuralType(LengthsType(), tuple('B')),
},
)
Expand Down

0 comments on commit 51120d9

Please sign in to comment.