
Commit 8b01fa7

Remove duplicates in WordSim353 when combining segments (#192)
* Remove duplicates in WordSim353 when combining segments

* Fix path

* Update test
leezu authored and szha committed Jul 3, 2018
1 parent 54e613f commit 8b01fa7
Showing 3 changed files with 20 additions and 17 deletions.
gluonnlp/data/word_embedding_evaluation.py (4 changes: 2 additions & 2 deletions)
@@ -256,7 +256,7 @@ def _get_data(self):
                     self.root,
                     'wordsim353_sim_rel/wordsim_similarity_goldstandard.txt'))
 
-        return [row for row in CorpusDataset(paths)]
+        return list({tuple(row) for row in CorpusDataset(paths)})
 
 
 @register(segment=['full', 'dev', 'test'])
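For context on the hunk above: CorpusDataset yields each row as a list, and lists are not hashable, so every row is first converted to a tuple before set-based deduplication. A minimal sketch of the idiom, using illustrative toy rows rather than the real dataset files:

    # Toy rows standing in for CorpusDataset output; the duplicate models a
    # word pair that appears in more than one WordSim353 segment with an
    # identical rating (values here are illustrative only).
    rows = [['money', 'cash', '9.2'],
            ['tiger', 'cat', '7.4'],
            ['money', 'cash', '9.2']]

    unique_rows = list({tuple(row) for row in rows})
    assert len(unique_rows) == 2

One trade-off worth noting: a set does not preserve insertion order, so the combined segment comes back in arbitrary order; dict.fromkeys(tuple(row) for row in rows) would deduplicate while keeping first-seen order.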
@@ -847,7 +847,7 @@ class BiggerAnalogyTestSet(WordAnalogyEvaluationDataset):
 
     def __init__(self, category=None, form_analogy_pairs=True,
                  drop_alternative_solutions=True, root=os.path.join(
-                     _get_home_dir(), 'datasets', 'simverb3500')):
+                     _get_home_dir(), 'datasets', 'bigger_analogy')):
         self.form_analogy_pairs = form_analogy_pairs
         self.drop_alternative_solutions = drop_alternative_solutions
         self.category = category
tests/unittest/test_datasets.py (18 changes: 11 additions & 7 deletions)
@@ -219,13 +219,17 @@ def _assert_similarity_dataset(data):
 
 
 @flaky(max_runs=2, min_passes=1)
-def test_wordsim353():
-    for segment, length in (("all", 252 + 203), ("relatedness", 252),
-                            ("similarity", 203)):
-        data = nlp.data.WordSim353(segment=segment, root=os.path.join(
-            'tests', 'externaldata', 'wordsim353'))
-        assert len(data) == length, len(data)
-        _assert_similarity_dataset(data)
+@pytest.mark.parametrize('segment,length', [('all', 352), ('relatedness', 252),
+                                            ('similarity', 203)])
+def test_wordsim353(segment, length):
+    # 'all' has length 352 as the original dataset contains the 'money/cash'
+    # pair twice with different similarity ratings, which was fixed by the
+    # http://alfonseca.org/eng/research/wordsim353.html version of the dataset
+    # that we are using.
+    data = nlp.data.WordSim353(segment=segment, root=os.path.join(
+        'tests', 'externaldata', 'wordsim353'))
+    assert len(data) == length, len(data)
+    _assert_similarity_dataset(data)
 
 
 def test_men():
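A quick sanity check on the lengths in the new parametrization: the relatedness and similarity segments hold 252 and 203 pairs, i.e. 455 rows before deduplication (the old expected length for 'all'), so 352 unique rows imply 103 identical rows shared between the two segments:

    raw_total = 252 + 203    # rows before deduplication, the old 'all' length
    unique_total = 352       # new expected length of the 'all' segment
    shared = raw_total - unique_total
    assert shared == 103     # rows rated identically in both segments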
tests/unittest/test_vocab_embed.py (15 changes: 7 additions & 8 deletions)
@@ -936,11 +936,11 @@ def test_word_embedding_similarity_evaluation_models(similarity_function):
         similarity_function=similarity_function)
     evaluator.initialize()
 
-    words1, words2 = mx.nd.array(words1), mx.nd.array(words2)
+    words1, words2 = nd.array(words1), nd.array(words2)
     pred_similarity = evaluator(words1, words2)
 
     sr = stats.spearmanr(pred_similarity.asnumpy(), np.array(scores))
-    assert np.isclose(0.6194264760578906, sr.correlation)
+    assert np.isclose(0.6076485693769645, sr.correlation)
 
 
 @pytest.mark.parametrize(
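The mx.nd to nd rename in the hunk above assumes a module-level from mxnet import nd alias (the import sits outside this hunk), and the new expected correlation presumably reflects the deduplicated 'all' segment. For reference, a hedged sketch of the scoring pattern the assertion relies on; spearman_score is a hypothetical helper, not part of the test module:

    from mxnet import nd
    import numpy as np
    from scipy import stats

    def spearman_score(pred_similarity, gold_scores):
        # Rank correlation between predicted similarities (an mxnet NDArray)
        # and the gold similarity ratings from the dataset.
        sr = stats.spearmanr(pred_similarity.asnumpy(), np.array(gold_scores))
        return sr.correlation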
@@ -957,7 +957,7 @@ def test_word_embedding_analogy_evaluation_models(analogy_function):
 
     dataset_coded = [[vocab[d[0]], vocab[d[1]], vocab[d[2]], vocab[d[3]]]
                      for d in dataset]
-    dataset_coded_nd = mx.nd.array(dataset_coded)
+    dataset_coded_nd = nd.array(dataset_coded)
 
     for k in [1, 3]:
         for exclude_question_words in [True, False]:
@@ -974,17 +974,16 @@ def test_word_embedding_analogy_evaluation_models(analogy_function):
 
             # If we don't exclude inputs most predictions should be wrong
             words4 = dataset_coded_nd[:, 3]
-            accuracy = mx.nd.mean(pred_idxs[:, 0] == mx.nd.array(words4))
+            accuracy = nd.mean(pred_idxs[:, 0] == nd.array(words4))
             accuracy = accuracy.asscalar()
-            if exclude_question_words == False:
+            if not exclude_question_words:
                 assert accuracy <= 0.1
 
                 # Instead the model would predict W3 most of the time
-                accuracy_w3 = mx.nd.mean(
-                    pred_idxs[:, 0] == mx.nd.array(words3))
+                accuracy_w3 = nd.mean(pred_idxs[:, 0] == nd.array(words3))
                 assert accuracy_w3.asscalar() >= 0.89
 
-            elif exclude_question_words == True:
+            else:
                 # The wiki.simple vectors don't perform too good
                 assert accuracy >= 0.29
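The branching above reflects a well-known property of 3CosAdd-style analogy scoring: without masking, the top candidate is usually one of the question words themselves (most often the third word, as the test's W3 assertion checks), so top-1 accuracy against the answer word stays near zero unless the question words are excluded. A hedged sketch of that masking step; topk_excluding_questions is a hypothetical helper, not the gluonnlp implementation:

    from mxnet import nd

    def topk_excluding_questions(scores, question_idxs, k=1):
        # scores: (vocab_size,) similarity score for every candidate word;
        # question_idxs: vocabulary indices of the three question words.
        scores = scores.copy()
        for idx in question_idxs:
            scores[idx] = -1e18  # mask question words so they cannot win
        return nd.topk(scores, k=k, ret_typ='indices')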
