diff --git a/src/unitxt/operators.py b/src/unitxt/operators.py index ac7ea090f..3e5035f17 100644 --- a/src/unitxt/operators.py +++ b/src/unitxt/operators.py @@ -628,9 +628,8 @@ def _get_random_pattern(self, pattern_distribution) -> str: get_random().choices( pattern_distribution["patterns"], pattern_distribution["weights"], - k=1, + k=pattern_distribution["length"], ) - * pattern_distribution["length"] ) return string_to_add diff --git a/tests/test_operators.py b/tests/test_operators.py index 2972c204d..dab2f1678 100644 --- a/tests/test_operators.py +++ b/tests/test_operators.py @@ -2016,7 +2016,7 @@ def test_augment_prefix_suffix_model_input(self): "\t\t " in output0 ), f"Trailing whitespaces wrongly removed, yielding {output0}, although 'remove_existing_whitespaces' is False," # weighted suffixes - suffixes_dict = {"Q": 2, "R": 2, "S": 2, "T": 8} + suffixes_dict = {"Q": 2, "R": 2, "S": 2, "T": 10} operator = AugmentPrefixSuffix( augment_model_input=True, suffixes=suffixes_dict, @@ -2027,11 +2027,11 @@ def test_augment_prefix_suffix_model_input(self): assert ( len(outputs) == 500 ), f"outputs length {len(outputs)} is different from inputs length, which is 500." - actual_suffixes = [output["source"][-8:] for output in outputs] + actual_suffixes = [output["source"][-2:] for output in outputs] counter = Counter(actual_suffixes) assert ( - counter["TTTTTTTT"] > 125 - ), f'In a population of size 500, suffix "TTTTTTTT" is expected to be more frequent than {counter["T"]}' + counter["TT"] > counter["SS"] + ), f'In a population of size 500 , suffix "TT" ({counter["TT"]}) is expected to be more frequent than "SS" {counter["SS"]}' # just for code coverage of Augmentor.process_value and Augmentor.process class JustToCoverProcessValueOfAugmentor(Augmentor):