Skip to content

Commit

Permalink
s2s CPU
Browse files Browse the repository at this point in the history
  • Loading branch information
donglixp authored Apr 2, 2020
1 parent 3fde4a2 commit 1364126
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion s2s-ft/s2s_ft/modeling_decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -1656,7 +1656,7 @@ def get_dup_ngram_candidates(seq, n):
forbid_word_mask = torch.tensor(
buf_matrix, dtype=log_scores.dtype)
forbid_word_mask = torch.reshape(
forbid_word_mask, [batch_size * K, 1, vocab_size]).cuda()
forbid_word_mask, [batch_size * K, 1, vocab_size]).to(input_ids.device)
else:
forbid_word_mask = None
next_pos += 1
Expand Down

0 comments on commit 1364126

Please sign in to comment.