-
-
Notifications
You must be signed in to change notification settings - Fork 4.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[MRG] Wrapper for FastText #847
Changes from 39 commits
55a4fc9
e916f7e
e5416ed
e64766b
c34cf37
c9b31f9
a0329af
0c0e2fa
cdefeb0
1aec5a2
e7368a3
fe283c2
9b36bc4
dfe1893
4a03f20
09b6ebe
7df4138
4c54d9b
a28f9f1
bf1182e
5a6b97b
cfb2e1c
b002765
27c0a14
81f8cbb
aa7e632
c780b9b
ccf5a47
708113b
b7de266
4d3d251
f2d13ce
3777423
6e20834
564ea0d
caeb275
784ffbf
20fe6f2
3b9483b
f5cdfb6
700dd26
d30ea56
bb6e538
c7a5d07
734057b
56d89e9
dc51096
bb48663
b58dd53
461a6b4
9137090
e5ae899
b98b40f
5eb8f75
27bec7b
ef0e1e2
ab07ef9
2f37b04
b2ff794
7b0874a
a7bceb6
dee9f97
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -45,6 +45,15 @@ def save(self, *args, **kwargs): | |
kwargs['ignore'] = kwargs.get('ignore', ['syn0norm']) | ||
super(KeyedVectors, self).save(*args, **kwargs) | ||
|
||
def word_vec(self, word, use_norm=False):
    """
    Accept a single word as input and return its representation in vector
    space, as a 1D numpy array.

    If `use_norm` is True, return the L2-normalized vector from `syn0norm`;
    otherwise return the raw vector from `syn0`.

    Raises KeyError if `word` is not in the vocabulary.
    """
    try:
        # Single dict lookup (EAFP) instead of a membership test followed
        # by a second lookup of the same key.
        entry = self.vocab[word]
    except KeyError:
        raise KeyError("word '%s' not in vocabulary" % word)
    if use_norm:
        return self.syn0norm[entry.index]
    return self.syn0[entry.index]
|
||
def most_similar(self, positive=[], negative=[], topn=10, restrict_vocab=None, indexer=None): | ||
""" | ||
Find the top-N most similar words. Positive words contribute positively towards the | ||
|
@@ -89,11 +98,10 @@ def most_similar(self, positive=[], negative=[], topn=10, restrict_vocab=None, i | |
for word, weight in positive + negative: | ||
if isinstance(word, ndarray): | ||
mean.append(weight * word) | ||
elif word in self.vocab: | ||
mean.append(weight * self.syn0norm[self.vocab[word].index]) | ||
all_words.add(self.vocab[word].index) | ||
else: | ||
raise KeyError("word '%s' not in vocabulary" % word) | ||
mean.append(weight * self.word_vec(word)) | ||
if word in self.vocab: | ||
all_words.add(self.vocab[word].index) | ||
if not mean: | ||
raise ValueError("cannot compute similarity with no input") | ||
mean = matutils.unitvec(array(mean).mean(axis=0)).astype(REAL) | ||
|
@@ -229,22 +237,14 @@ def most_similar_cosmul(self, positive=[], negative=[], topn=10): | |
# allow calls like most_similar_cosmul('dog'), as a shorthand for most_similar_cosmul(['dog']) | ||
positive = [positive] | ||
|
||
all_words = set() | ||
|
||
def word_vec(word): | ||
if isinstance(word, ndarray): | ||
return word | ||
elif word in self.vocab: | ||
all_words.add(self.vocab[word].index) | ||
return self.syn0norm[self.vocab[word].index] | ||
else: | ||
raise KeyError("word '%s' not in vocabulary" % word) | ||
|
||
positive = [word_vec(word) for word in positive] | ||
negative = [word_vec(word) for word in negative] | ||
positive = [self.word_vec(word, use_norm=True) for word in positive] | ||
negative = [self.word_vec(word, use_norm=True) for word in negative] | ||
if not positive: | ||
raise ValueError("cannot compute similarity with no input") | ||
|
||
all_words = set([self.vocab[word].index for word in positive+negative if word in self.vocab]) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What is the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. To remove the input words from the returned There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Eh, never mind, the review snippet showed me the code for Square brackets |
||
|
||
# equation (4) of Levy & Goldberg "Linguistic Regularities...", | ||
# with distances shifted to [0,1] per footnote (7) | ||
pos_dists = [((1 + dot(self.syn0norm, term)) / 2) for term in positive] | ||
|
@@ -314,7 +314,7 @@ def doesnt_match(self, words): | |
logger.debug("using words %s" % words) | ||
if not words: | ||
raise ValueError("cannot select a word from an empty list") | ||
vectors = vstack(self.syn0norm[self.vocab[word].index] for word in words).astype(REAL) | ||
vectors = vstack(self.word_vec(word) for word in words).astype(REAL) | ||
mean = matutils.unitvec(vectors.mean(axis=0)).astype(REAL) | ||
dists = dot(vectors, mean) | ||
return sorted(zip(dists, words))[0][1] | ||
|
@@ -344,9 +344,9 @@ def __getitem__(self, words): | |
""" | ||
if isinstance(words, string_types): | ||
# allow calls like trained_model['office'], as a shorthand for trained_model[['office']] | ||
return self.syn0[self.vocab[words].index] | ||
return self.word_vec(words) | ||
|
||
return vstack([self.syn0[self.vocab[word].index] for word in words]) | ||
return vstack([self.word_vec(word) for word in words]) | ||
|
||
def __contains__(self, word):
    """Membership test: `word in keyed_vectors` is True iff `word` is in the vocabulary."""
    return word in self.vocab
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -420,13 +420,13 @@ def __init__( | |
texts are longer than 10000 words, but the standard cython code truncates to that maximum.) | ||
|
||
""" | ||
|
||
if FAST_VERSION == -1: | ||
logger.warning('Slow version of {0} is being used'.format(__name__)) | ||
else: | ||
logger.debug('Fast version of {0} is being used'.format(__name__)) | ||
|
||
self.wv = KeyedVectors() # wv --> KeyedVectors | ||
self.initialize_word_vectors() | ||
self.sg = int(sg) | ||
self.cum_table = None # for negative sampling | ||
self.vector_size = int(size) | ||
|
@@ -460,6 +460,9 @@ def __init__( | |
self.build_vocab(sentences, trim_rule=trim_rule) | ||
self.train(sentences) | ||
|
||
def initialize_word_vectors(self):
    """Create the KeyedVectors instance that stores the learned word vectors.

    Subclasses override this to install a specialised KeyedVectors type
    (e.g. the FastText wrapper installs one that handles out-of-vocabulary words).
    """
    self.wv = KeyedVectors()
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Remove comment, adds nothing. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done |
||
|
||
def make_cum_table(self, power=0.75, domain=2**31 - 1): | ||
""" | ||
Create a cumulative-distribution table using stored vocabulary word counts for | ||
|
@@ -1617,4 +1620,4 @@ def __iter__(self): | |
model.accuracy(args.accuracy) | ||
|
||
logger.info("finished running %s", program) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,231 @@ | ||
#!/usr/bin/env python | ||
# -*- coding: utf-8 -*- | ||
|
||
|
||
""" | ||
Python wrapper around word representation learning from FastText, a library for efficient learning | ||
of word representations and sentence classification [1]. | ||
|
||
This module allows training a word embedding from a training corpus with the additional ability | ||
to obtain word vectors for out-of-vocabulary words, using the fastText C implementation. | ||
|
||
The wrapped model can NOT be updated with new documents for online training -- use gensim's | ||
`Word2Vec` for that. | ||
|
||
Example: | ||
|
||
>>> model = gensim.models.wrappers.FastText('/Users/kofola/fastText/fasttext', corpus_file='text8')
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Shouldn't this be gensim.models.wrappers.FastText(..) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yep, fixed. Thanks |
||
>>> print(model[word])  # prints vector for given word
|
||
.. [1] https://github.com/facebookresearch/fastText#enriching-word-vectors-with-subword-information | ||
|
||
""" | ||
|
||
|
||
import logging | ||
import tempfile | ||
import os | ||
import struct | ||
|
||
import numpy as np | ||
|
||
from gensim import utils | ||
from gensim.models.keyedvectors import KeyedVectors | ||
from gensim.models.word2vec import Word2Vec | ||
|
||
from six import string_types | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class FastTextKeyedVectors(KeyedVectors):
    """
    KeyedVectors subclass that can also produce vectors for out-of-vocabulary
    words, by averaging the vectors of the word's char ngrams.

    Relies on `syn0_all`, `ngrams`, `min_n` and `max_n` being populated by
    `FastText.load_binary_data` before OOV lookups are attempted.
    """

    def word_vec(self, word, use_norm=False):
        """
        Return the vector for `word`.

        In-vocabulary words are delegated to `KeyedVectors.word_vec`.
        Out-of-vocabulary words are reconstructed as the average of the vectors
        of those of their char ngrams that exist in the model.

        Raises KeyError when the word is OOV and none of its ngrams are known.

        NOTE(review): for OOV words `use_norm` has no effect -- the averaged
        ngram vector is returned unnormalized; confirm this is intended.
        """
        if word in self.vocab:
            return super(FastTextKeyedVectors, self).word_vec(word, use_norm)
        ngrams = FastText.compute_ngrams(word, self.min_n, self.max_n)
        # Average only over ngrams actually present in the model: ngrams absent
        # from `self.ngrams` contribute nothing to the sum, so counting them in
        # the divisor would shrink every OOV vector toward zero.
        known_ngrams = [ngram for ngram in ngrams if ngram in self.ngrams]
        if not known_ngrams:  # no ngrams of the word are present in self.ngrams
            raise KeyError('all ngrams for word %s absent from model' % word)
        vec = np.zeros(self.syn0_all.shape[1])
        for ngram in known_ngrams:
            vec += self.syn0_all[self.ngrams[ngram]]
        return vec / len(known_ngrams)
|
||
|
||
class FastText(Word2Vec):
    """
    Class for word vector training using FastText. Communication between FastText and Python
    takes place by working with data files on disk and calling the FastText binary with
    subprocess.call().

    Implements functionality similar to [fasttext.py](https://github.com/salestock/fastText.py),
    improving speed and scope of functionality like `most_similar`, `accuracy` by extracting
    vectors into a numpy matrix.
    """

    def initialize_word_vectors(self):
        """Install a FastText-specific KeyedVectors, so that lookups of
        out-of-vocabulary words can fall back to char-ngram vectors."""
        self.wv = FastTextKeyedVectors()

    @classmethod
    def train(cls, ft_path, corpus_file, output_file=None, model='cbow', size=100, alpha=0.025, window=5, min_count=5,
              loss='ns', sample=1e-3, negative=5, iter=5, min_n=3, max_n=6, sorted_vocab=1, threads=12):
        """
        Train a model by shelling out to the fastText binary, then load the result
        back as a gensim model.

        `ft_path` is the path to the FastText executable, e.g. `/home/kofola/fastText/fasttext`.

        `corpus_file` is the filename of the text file to be used for training the FastText model.
        Expects file to contain space-separated tokens in a single line.

        `output_file` is the base name for the `.vec`/`.bin` files fastText writes.
        Defaults to `ft_model` inside the system temp directory.

        `model` defines the training algorithm. By default, cbow is used. Accepted values are
        cbow, skipgram.

        `size` is the dimensionality of the feature vectors.

        `window` is the maximum distance between the current and predicted word within a sentence.

        `alpha` is the initial learning rate.

        `min_count` = ignore all words with total frequency lower than this.

        `loss` = defines training objective. Allowed values are `hs` (hierarchical softmax),
        `ns` (negative sampling) and `softmax`. Defaults to `ns`.

        `sample` = threshold for configuring which higher-frequency words are randomly downsampled;
        default is 1e-3, useful range is (0, 1e-5).

        `negative` = the value for negative specifies how many "noise words" should be drawn
        (usually between 5-20). Default is 5. If set to 0, no negative sampling is used.
        Only relevant when `loss` is set to `ns`.

        `iter` = number of iterations (epochs) over the corpus. Default is 5.

        `min_n` = min length of char ngrams to be used for training word representations. Default is 3.

        `max_n` = max length of char ngrams to be used for training word representations. Set `max_n`
        to be lesser than `min_n` to avoid char ngrams being used. Default is 6.

        `sorted_vocab` = if 1 (default), sort the vocabulary by descending frequency before
        assigning word indexes. NOTE(review): currently not forwarded to the fastText binary.

        `threads` = number of threads the fastText binary may use. Default is 12.
        """
        output_file = output_file or os.path.join(tempfile.gettempdir(), 'ft_model')
        ft_args = {
            'input': corpus_file,
            'output': output_file,
            'lr': alpha,
            'dim': size,
            'ws': window,
            'epoch': iter,
            'minCount': min_count,
            'neg': negative,
            'loss': loss,
            'minn': min_n,
            'maxn': max_n,
            'thread': threads,
            't': sample
        }
        cmd = [ft_path, model]
        for option, value in ft_args.items():
            cmd.append("-%s" % option)
            cmd.append(str(value))

        # Raises if the binary exits with a non-zero status; its stdout is not needed.
        utils.check_output(args=cmd)
        return cls.load_fasttext_format(output_file)

    @classmethod
    def load_fasttext_format(cls, model_file):
        """
        Load a model trained by the fastText binary: word vectors from
        `<model_file>.vec`, hyperparameters, vocabulary counts and ngram
        vectors from `<model_file>.bin`.
        """
        model = cls.load_word2vec_format('%s.vec' % model_file)
        model.load_binary_data('%s.bin' % model_file)
        return model

    def load_binary_data(self, model_file):
        """Load model parameters, dictionary counts and vectors from the `.bin` file."""
        with open(model_file, 'rb') as f:
            self.load_model_params(f)
            self.load_dict(f)
            self.load_vectors(f)

    def load_model_params(self, f):
        """Read the training hyperparameters stored at the head of the `.bin` file."""
        (dim, ws, epoch, minCount, neg, _, loss, model, bucket, minn, maxn, _, t) = self.struct_unpack(f, '@12i1d')
        self.size = dim
        self.window = ws
        self.iter = epoch
        self.min_count = minCount
        self.negative = neg
        self.loss = loss
        # fastText serializes `model` as an int enum (cbow=1, skipgram=2, supervised=3);
        # the previous comparison against the string 'skipgram' was always False.
        self.sg = model == 2
        self.bucket = bucket
        self.wv.min_n = minn
        self.wv.max_n = maxn
        self.sample = t

    def load_dict(self, f):
        """
        Read the fastText dictionary section and copy the per-word counts into the
        vocabulary already loaded from the `.vec` file (same word order expected).
        """
        # NOTE(review): the first header int is the dictionary size, not the vector
        # dimensionality -- the value is unused here either way; confirm against the
        # fastText version being targeted.
        (_, nwords, _) = self.struct_unpack(f, '@3i')
        assert len(self.wv.vocab) == nwords, 'mismatch between vocab sizes'
        ntokens, = self.struct_unpack(f, '@q')  # total token count; unused
        for i in range(nwords):
            # Each word is stored as a NUL-terminated byte string.
            word = ''
            char, = self.struct_unpack(f, '@c')
            char = char.decode()
            while char != '\x00':
                word += char
                char, = self.struct_unpack(f, '@c')
                char = char.decode()
            count, _ = self.struct_unpack(f, '@ib')  # count, plus one skipped byte
            _ = self.struct_unpack(f, '@i')  # skipped int field
            assert self.wv.vocab[word].index == i, 'mismatch between gensim word index and fastText word index'
            self.wv.vocab[word].count = count

    def load_vectors(self, f):
        """Read the raw (word + ngram-bucket) vector matrix and build the ngram lookups."""
        num_vectors, dim = self.struct_unpack(f, '@2q')
        float_size = struct.calcsize('@f')
        if float_size == 4:
            dtype = np.dtype(np.float32)
        elif float_size == 8:
            dtype = np.dtype(np.float64)
        else:
            # Previously fell through to a NameError on `dtype`; fail explicitly instead.
            raise ValueError("unexpected platform float size: %d" % float_size)

        self.num_original_vectors = num_vectors
        self.wv.syn0_all = np.fromstring(f.read(num_vectors * dim * float_size), dtype=dtype)
        self.wv.syn0_all = self.wv.syn0_all.reshape((num_vectors, dim))
        self.init_ngrams()

    def struct_unpack(self, f, fmt):
        """Read exactly the number of bytes `fmt` describes from `f` and unpack them."""
        num_bytes = struct.calcsize(fmt)
        return struct.unpack(fmt, f.read(num_bytes))

    def init_ngrams(self):
        """
        Compute the ngrams of every vocabulary word and keep vectors only for those
        ngrams, discarding the rest of the hash-bucket matrix to save memory.
        """
        self.wv.ngrams = {}
        all_ngrams = []
        for w in self.wv.vocab:
            all_ngrams += self.compute_ngrams(w, self.wv.min_n, self.wv.max_n)
        all_ngrams = set(all_ngrams)
        self.num_ngram_vectors = len(all_ngrams)
        ngram_indices = []
        for i, ngram in enumerate(all_ngrams):
            ngram_hash = self.ft_hash(ngram)
            # fastText stores an ngram's vector at row `nwords + (hash % bucket)`:
            # the modulo applies to the hash alone. The previous
            # `(nwords + hash) % bucket` indexed into the word-vector region.
            ngram_indices.append(len(self.wv.vocab) + ngram_hash % self.bucket)
            self.wv.ngrams[ngram] = i
        self.wv.syn0_all = self.wv.syn0_all.take(ngram_indices, axis=0)

    @staticmethod
    def compute_ngrams(word, min_n, max_n):
        """
        Return the set of char ngrams of lengths `min_n`..`max_n` (inclusive) for
        `word`, after wrapping it in the fastText boundary markers '<' and '>'.
        """
        BOW, EOW = ('<', '>')
        extended_word = BOW + word + EOW
        ngrams = set()
        for i in range(len(extended_word) - min_n + 1):
            # Clamp the ngram length to `max_n` AND to the characters remaining;
            # the previous bound could emit ngrams longer than `max_n` on long words.
            for j in range(min_n, min(max_n, len(extended_word) - i) + 1):
                ngrams.add(extended_word[i:i + j])
        return ngrams

    @staticmethod
    def ft_hash(string):
        """
        Reproduce the hash method used in fastText (32-bit FNV-1a over characters).

        NOTE(review): fastText hashes UTF-8 bytes; hashing code points via ord()
        diverges for non-ASCII ngrams -- confirm for the corpora being targeted.
        """
        h = np.uint32(2166136261)
        for c in string:
            h = h ^ np.uint32(ord(c))
            h = h * np.uint32(16777619)
        return h
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Dead code test, can never reach here (above line would throw a KeyError).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The
KeyError
has been removed.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No, it's still there, on line 66.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That line raises a
KeyError
in caseword in self.vocab
isFalse
. So in case it'sTrue
, line 115 would be executed.Also,
word_vec
has been overriden in theKeyedVectors
subclass forFastText
.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, my point is -- isn't it always
True
? How could it beFalse
, when that would raise an exception at the line above? The test seems superfluous.But if subclasses can make
word_vec()
behave differently (not raise for missing words), then it makes sense. Not sure what the general contract forword_vec()
behaviour is.