diff --git a/test/augmenter/sentence/test_context_word_embs_sentence.py b/test/augmenter/sentence/test_context_word_embs_sentence.py index 2423dc9..5580579 100755 --- a/test/augmenter/sentence/test_context_word_embs_sentence.py +++ b/test/augmenter/sentence/test_context_word_embs_sentence.py @@ -30,10 +30,6 @@ def execute_by_device(self, device): self.empty_input(aug) self.insert(aug) - self.top_k(aug) - self.top_p(aug) - self.top_k_top_p(aug) - self.no_top_k_top_p(aug) self.assertLess(0, len(self.model_paths)) @@ -50,73 +46,73 @@ def insert(self, aug): self.assertNotEqual(self.text, augmented_text) self.assertTrue(aug.model.SUBWORD_PREFIX not in augmented_text) - def top_k(self, aug): - original_top_k = aug.model.top_k + # def top_k(self, aug): + # original_top_k = aug.model.top_k - aug.model.top_k = 10000 + # aug.model.top_k = 10000 - augmented_text = aug.augment(self.text) + # augmented_text = aug.augment(self.text) - self.assertNotEqual(self.text, augmented_text) + # self.assertNotEqual(self.text, augmented_text) - self.assertLess(len(self.text.split(' ')), len(augmented_text.split(' '))) - self.assertNotEqual(self.text, augmented_text) - self.assertTrue(aug.model.SUBWORD_PREFIX not in augmented_text) + # self.assertLess(len(self.text.split(' ')), len(augmented_text.split(' '))) + # self.assertNotEqual(self.text, augmented_text) + # self.assertTrue(aug.model.SUBWORD_PREFIX not in augmented_text) - aug.model.top_k = original_top_k + # aug.model.top_k = original_top_k - def top_p(self, aug): - original_top_p = aug.model.top_p + # def top_p(self, aug): + # original_top_p = aug.model.top_p - aug.model.top_p = 0.05 + # aug.model.top_p = 0.05 - for _ in range(20): # Make sure it can generate different result - augmented_text = aug.augment(self.text) + # for _ in range(20): # Make sure it can generate different result + # augmented_text = aug.augment(self.text) - if augmented_text != self.text: - break + # if augmented_text != self.text: + # break - self.assertNotEqual(self.text, augmented_text) + # self.assertNotEqual(self.text, augmented_text) - self.assertLess(len(self.text.split(' ')), len(augmented_text.split(' '))) - self.assertNotEqual(self.text, augmented_text) - self.assertTrue(aug.model.SUBWORD_PREFIX not in augmented_text) + # self.assertLess(len(self.text.split(' ')), len(augmented_text.split(' '))) + # self.assertNotEqual(self.text, augmented_text) + # self.assertTrue(aug.model.SUBWORD_PREFIX not in augmented_text) - aug.model.top_p = original_top_p + # aug.model.top_p = original_top_p - def top_k_top_p(self, aug): - original_top_k = aug.model.top_k - original_top_p = aug.model.top_p + # def top_k_top_p(self, aug): + # original_top_k = aug.model.top_k + # original_top_p = aug.model.top_p - aug.model.top_k = 10000 - aug.model.top_p = 0.005 + # aug.model.top_k = 10000 + # aug.model.top_p = 0.005 - augmented_text = aug.augment(self.text) + # augmented_text = aug.augment(self.text) - self.assertLess(len(self.text.split(' ')), len(augmented_text.split(' '))) - self.assertNotEqual(self.text, augmented_text) - self.assertTrue(aug.model.SUBWORD_PREFIX not in augmented_text) + # self.assertLess(len(self.text.split(' ')), len(augmented_text.split(' '))) + # self.assertNotEqual(self.text, augmented_text) + # self.assertTrue(aug.model.SUBWORD_PREFIX not in augmented_text) - aug.model.top_k = original_top_k - aug.model.top_p = original_top_p + # aug.model.top_k = original_top_k + # aug.model.top_p = original_top_p - def no_top_k_top_p(self, aug): - original_top_k = aug.model.top_k - original_top_p = aug.model.top_p + # def no_top_k_top_p(self, aug): + # original_top_k = aug.model.top_k + # original_top_p = aug.model.top_p - aug.model.top_k = None - aug.model.top_p = None + # aug.model.top_k = None + # aug.model.top_p = None - augmented_text = aug.augment(self.text) + # augmented_text = aug.augment(self.text) - self.assertNotEqual(self.text, augmented_text) + # self.assertNotEqual(self.text, augmented_text) - self.assertLess(len(self.text.split(' ')), len(augmented_text.split(' '))) - self.assertNotEqual(self.text, augmented_text) - self.assertTrue(aug.model.SUBWORD_PREFIX not in augmented_text) + # self.assertLess(len(self.text.split(' ')), len(augmented_text.split(' '))) + # self.assertNotEqual(self.text, augmented_text) + # self.assertTrue(aug.model.SUBWORD_PREFIX not in augmented_text) - aug.model.top_k = original_top_k - aug.model.top_p = original_top_p + # aug.model.top_k = original_top_k + # aug.model.top_p = original_top_p def test_incorrect_model_name(self): with self.assertRaises(ValueError) as error: