Skip to content

Commit

Permalink
Fix suffix prefix not return same prefix/suffix repeated
Browse files Browse the repository at this point in the history
Signed-off-by: Yoav Katz <katz@il.ibm.com>
  • Loading branch information
yoavkatz committed Dec 18, 2023
1 parent d9ad7dd commit 14456ef
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 6 deletions.
3 changes: 1 addition & 2 deletions src/unitxt/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -628,9 +628,8 @@ def _get_random_pattern(self, pattern_distribution) -> str:
get_random().choices(
pattern_distribution["patterns"],
pattern_distribution["weights"],
k=1,
k=pattern_distribution["length"],
)
* pattern_distribution["length"]
)
return string_to_add

Expand Down
8 changes: 4 additions & 4 deletions tests/test_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -2016,7 +2016,7 @@ def test_augment_prefix_suffix_model_input(self):
"\t\t " in output0
), f"Trailing whitespaces wrongly removed, yielding {output0}, although 'remove_existing_whitespaces' is False,"
# weighted suffixes
suffixes_dict = {"Q": 2, "R": 2, "S": 2, "T": 8}
suffixes_dict = {"Q": 2, "R": 2, "S": 2, "T": 10}
operator = AugmentPrefixSuffix(
augment_model_input=True,
suffixes=suffixes_dict,
Expand All @@ -2027,11 +2027,11 @@ def test_augment_prefix_suffix_model_input(self):
assert (
len(outputs) == 500
), f"outputs length {len(outputs)} is different from inputs length, which is 500."
actual_suffixes = [output["source"][-8:] for output in outputs]
actual_suffixes = [output["source"][-2:] for output in outputs]
counter = Counter(actual_suffixes)
assert (
counter["TTTTTTTT"] > 125
), f'In a population of size 500, suffix "TTTTTTTT" is expected to be more frequent than {counter["T"]}'
counter["TT"] > counter["SS"]
), f'In a population of size 500 , suffix "TT" ({counter["TT"]}) is expected to be more frequent than "SS" {counter["SS"]}'

# just for code coverage of Augmentor.process_value and Augmentor.process
class JustToCoverProcessValueOfAugmentor(Augmentor):
Expand Down

0 comments on commit 14456ef

Please sign in to comment.