Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Added function "predict_output_word" to predict the output word given the context words. Fixes issue #863. #1209

Merged
merged 8 commits into from
Mar 20, 2017
21 changes: 21 additions & 0 deletions gensim/models/word2vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -1203,6 +1203,27 @@ def similarity(self, w1, w2):
def n_similarity(self, ws1, ws2):
return self.wv.n_similarity(ws1, ws2)

def predict_output_word(self, context_words_list, topn=10):
#verify that required parameters have not been discarded
if not hasattr(self.wv, 'syn0') or not hasattr(self, 'syn1neg'):
raise RuntimeError("Parameters required for predicting the output words not found.")

word_vocabs = [self.wv.vocab[w] for w in context_words_list if w in self.wv.vocab]

word2_indices = []
for pos, word in enumerate(word_vocabs):
word2_indices.append(word.index)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please use list comprehension


l1 = np_sum(self.wv.syn0[word2_indices], axis=0)
if word2_indices and self.cbow_mean:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if word_vocabs is empty, then return None with a warning

l1 /= len(word2_indices)

if self.negative :
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please raise exception

if not self.negative:
            raise RuntimeError("We have currently only implemented for negative sampling")
``

prob_values = exp(dot(l1, self.syn1.T)) # propagate hidden -> output and take softmax to get probabilities
prob_values /= sum(prob_values)
top_indices = matutils.argsort(prob_values, topn=topn, reverse=True)
return [(self.wv.index2word[index1], prob_values[index1]) for index1 in top_indices] #returning the most probable output words with their probabilities

def init_sims(self, replace=False):
"""
init_sims() resides in KeyedVectors because it deals with syn0 mainly, but because syn1 is not an attribute
Expand Down