-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
124 lines (99 loc) · 4.28 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
CS224N 2018-19: Homework 5
utils.py: Utility functions (sentence padding, corpus reading, batching)
Pencheng Yin <pcyin@cs.cmu.edu>
Sahil Chopra <schopra8@stanford.edu>
"""
import math
from typing import List
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
def pad_sents_char(sents, char_pad_token, max_word_length=21):
    """ Pad list of sentences according to the longest sentence in the batch and max_word_length.
    @param sents (list[list[list[int]]]): list of sentences, result of `words2charindices()`
        from `vocab.py`
    @param char_pad_token (int): index of the character-padding token
    @param max_word_length (int): number of characters every word is padded or truncated to;
        words longer than this are truncated (default 21)
    @returns sents_padded (list[list[list[int]]]): list of sentences where sentences/words shorter
        than the max length sentence/word are padded out with the appropriate pad token, such that
        each sentence in the batch now has same number of words and each word has an equal
        number of characters. Every word (including padding words) is a distinct list object.
        An empty batch yields an empty list.
        Output shape: (batch_size, max_sentence_length, max_word_length)
    """
    # An empty batch has no longest sentence; default=0 yields [] instead of letting
    # max() raise ValueError.
    max_sent_len = max((len(s) for s in sents), default=0)

    sents_padded = []
    for sent in sents:
        padded_words = []
        for word in sent:
            # Truncate long words, then right-pad short ones to exactly max_word_length.
            chars = list(word[:max_word_length])
            chars += [char_pad_token] * (max_word_length - len(chars))
            padded_words.append(chars)
        # Padding words are whole words made only of padding characters (which is why
        # pad_sents(), which pads with a single scalar token, is not reused here).
        # Each one is built fresh: `[[pad] * L] * n` would make every padding word the
        # same list object, so mutating one would silently mutate all of them.
        padded_words.extend(
            [char_pad_token] * max_word_length for _ in range(max_sent_len - len(sent))
        )
        sents_padded.append(padded_words)
    return sents_padded
def pad_sents(sents, pad_token):
    """ Pad list of sentences according to the longest sentence in the batch.
    @param sents (list[list[int]]): list of sentences, where each sentence
        is represented as a list of words
    @param pad_token (int): padding token
    @returns sents_padded (list[list[int]]): list of sentences where sentences shorter
        than the max length sentence are padded out with the pad_token, such that
        each sentence in the batch now has equal length. An empty batch yields [].
        Output shape: (batch_size, max_sentence_length)
    """
    # default=0 keeps an empty batch from raising ValueError inside max().
    max_len = max((len(s) for s in sents), default=0)
    # Each padded sentence is a new list: the original words followed by right-padding.
    return [list(s) + [pad_token] * (max_len - len(s)) for s in sents]
def read_corpus(file_path, source, encoding='utf-8'):
    """ Read file, where each sentence is on its own line (delineated by a newline).
    @param file_path (str): path to file containing corpus
    @param source (str): "tgt" or "src" indicating whether text
        is of the source language or target language
    @param encoding (str): text encoding of the file (default 'utf-8')
    @returns data (list[list[str]]): one list of tokens per line, split on single spaces
        after stripping surrounding whitespace; target sentences are wrapped in '<s>' ... '</s>'.
        Note a blank line yields [''] (or ['<s>', '', '</s>'] for the target side).
    """
    data = []
    # `with` guarantees the handle is closed (previously it leaked until garbage
    # collection); an explicit encoding avoids depending on the platform locale default.
    with open(file_path, encoding=encoding) as corpus_file:
        for line in corpus_file:
            sent = line.strip().split(' ')
            # only append <s> and </s> to the target sentence
            if source == 'tgt':
                sent = ['<s>'] + sent + ['</s>']
            data.append(sent)
    return data
def batch_iter(data, batch_size, shuffle=False):
    """ Generate (src_sents, tgt_sents) batches from `data`.

    Each batch holds up to `batch_size` consecutive examples (consecutive in the shuffled
    order when `shuffle` is True) and is ordered by source-sentence length, longest first.
    @param data (list of (src_sent, tgt_sent)): list of tuples containing source and target sentence
    @param batch_size (int): maximum number of examples per batch
    @param shuffle (boolean): whether to randomly shuffle the dataset (via np.random) before batching
    @yields (src_sents, tgt_sents) (list, list): parallel lists of source and target sentences
    """
    num_batches = math.ceil(len(data) / batch_size)
    order = list(range(len(data)))
    if shuffle:
        np.random.shuffle(order)

    for batch_idx in range(num_batches):
        start = batch_idx * batch_size
        batch = [data[j] for j in order[start:start + batch_size]]
        # Longest source sentence first; list.sort is stable, exactly like sorted().
        batch.sort(key=lambda example: len(example[0]), reverse=True)
        yield [example[0] for example in batch], [example[1] for example in batch]