diff --git a/src/unitxt/formats.py b/src/unitxt/formats.py index 24770024f..2b83422d2 100644 --- a/src/unitxt/formats.py +++ b/src/unitxt/formats.py @@ -55,7 +55,22 @@ def apply_capital_new_line_notation(text: str) -> str: return re.sub(r"[\n(\\N)]*(\\N)+", r"\n", text) -class SystemFormat(Format): +class BaseFormat(Format): + demos_field: str = "demos" + + @staticmethod + def _retrieve_field_and_pop_from_instance(instance, field_name) -> str: + if field_name is not None and field_name in instance: + field_value = instance[field_name] + instance.pop(field_name) + assert ( + field_value is not None + ), f"Value in field '{field_name}' should not be none. Received instance: {instance}" + return field_value + return "" + + +class SystemFormat(BaseFormat): r"""Generates the whole input to the model, from constant strings that are given as args, and from values found in specified fields of the instance. Important: formats can use '\N' notations that means new-line if no new-line before and no empty string before. @@ -113,50 +128,32 @@ class SystemFormat(Format): """ - demos_field: str = "demos" demo_format: str = "{source}\\N{target_prefix}{target}\n\n" # example: "User: {source}\nAgent: {target}\n\n" model_input_format: str = ( "{system_prompt}\\N{instruction}\\N{demos}{source}\\N{target_prefix}" ) format_args: Dict[str, str] = OptionalField(default_factory=dict) - @staticmethod - def _retrieve_field_and_assert_not_none(instance, field_name) -> str: - if field_name is not None and field_name in instance: - field_value = instance[field_name] - assert ( - field_value is not None - ), f"Value in field '{field_name}' should not be none. Received instance: {instance}" - return field_value - return "" - def process( self, instance: Dict[str, Any], stream_name: Optional[str] = None ) -> Dict[str, Any]: assert ( "source" in instance ), f"field 'source' is expected to be in the input instance. Received instance: {instance}" - source = self._retrieve_field_and_assert_not_none( + source = self._retrieve_field_and_pop_from_instance( instance=instance, field_name="source" ) - instruction = self._retrieve_field_and_assert_not_none( + instruction = self._retrieve_field_and_pop_from_instance( instance=instance, field_name="instruction" ) - target_prefix = self._retrieve_field_and_assert_not_none( + target_prefix = self._retrieve_field_and_pop_from_instance( instance=instance, field_name="target_prefix" ) - system_prompt = self._retrieve_field_and_assert_not_none( + system_prompt = self._retrieve_field_and_pop_from_instance( instance=instance, field_name="system_prompt" ) - if "target_prefix" in instance: - instance.pop("target_prefix") - if "instruction" in instance: - instance.pop("instruction") - if "system_prompt" in instance: - instance.pop("system_prompt") - demo_instances = [] if self.demos_field is not None and self.demos_field in instance: demos = instance[self.demos_field] @@ -187,3 +184,92 @@ def process( output = apply_capital_new_line_notation(output) instance["source"] = output return instance + + +class HFSystemFormat(BaseFormat): + r"""Formats the complete input for the model using the Hugginface chat template of a given model. + + HFSystemFormat expects the input instance to contain: + 1. A field named "system_prompt" whose value is a string (potentially empty) that delivers a task independent opening text. + 2. A field named "source" whose value is a string verbalizing the original values in the instance (as read + from the source dataset), in the context of the underlying task. + 3. A field named "instruction" that contains a (non-None) string. + 4. A field named with the value in arg 'demos_field', containing a list of dicts, each dict with fields "source" + and "target", representing a single demo. + 5. A field named "target_prefx" that contains a string to prefix the target in both each demo, and to end the whole generated prompt + + SystemFormat formats the above fields into a single string to be inputted to the model. This string overwrites + field "source" of the instance. + + Example: + HFSystemFormat(model_name="HuggingFaceH4/zephyr-7b-beta") + + Uses the template defined the in tokenizer_config.json of the model: + + "chat_template": "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}", + + See more details in https://huggingface.co/docs/transformers/main/en/chat_templating + + """ + + model_name: str + + def process( + self, instance: Dict[str, Any], stream_name: Optional[str] = None + ) -> Dict[str, Any]: + from transformers import AutoTokenizer + + tokenizer = AutoTokenizer.from_pretrained(self.model_name) + + assert ( + "source" in instance + ), f"field 'source' is expected to be in the input instance. Received instance: {instance}" + + source = self._retrieve_field_and_pop_from_instance( + instance=instance, field_name="source" + ) + + instruction = self._retrieve_field_and_pop_from_instance( + instance=instance, field_name="instruction" + ) + target_prefix = self._retrieve_field_and_pop_from_instance( + instance=instance, field_name="target_prefix" + ) + system_prompt = self._retrieve_field_and_pop_from_instance( + instance=instance, field_name="system_prompt" + ) + + messages = [ + { + "role": "system", + "content": system_prompt + + ("\n" if system_prompt != "" else "") + + instruction, + }, + ] + demo_instances = [] + if self.demos_field is not None and self.demos_field in instance: + demos = instance[self.demos_field] + assert ( + demos is not None and isoftype(demos, List[Dict[str, Any]]) + ), f"A list of dict-s is expected in field '{self.demos_field}'. Received instance: {instance}" + demo_instances = demos + instance.pop(self.demos_field) + + for demo_instance in demo_instances: + messages.extend( + [ + {"role": "user", "content": demo_instance["source"]}, + { + "role": "assistant", + "content": target_prefix + demo_instance["target"], + }, + ] + ) + messages.extend([{"role": "user", "content": source}]) + tokenized_chat = tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + + instance["source"] = tokenized_chat + target_prefix + return instance diff --git a/tests/library/test_formats.py b/tests/library/test_formats.py index 2120cc0df..8e339dd76 100644 --- a/tests/library/test_formats.py +++ b/tests/library/test_formats.py @@ -1,4 +1,4 @@ -from unitxt.formats import SystemFormat +from unitxt.formats import HFSystemFormat, SystemFormat from unitxt.test_utils.operators import ( check_operator, ) @@ -7,6 +7,58 @@ class TestFormats(UnitxtTestCase): + def test_hf_system_format(self): + instruction = "solve the math exercises" + + demo_instances = [ + {"source": "1+2", "target": "3", "instruction": instruction, "inputs": {}}, + {"source": "4-2", "target": "2", "instruction": instruction, "inputs": {}}, + ] + + inputs = [ + { + "source": "1+1", + "target": "2", + "instruction": instruction, + "demos": demo_instances, + "inputs": {}, + "target_prefix": "The answer is ", + "system_prompt": "You are a smart assistant.", + }, + { + "source": "3+2", + "target": "5", + "instruction": instruction, + "demos": demo_instances, + "inputs": {}, + "target_prefix": "The answer is ", + "system_prompt": "You are a smart assistant.", + }, + ] + + # imitating iclformat's add_instruction_after_demos=True, instruction is not "", and target_prefix ="" + system_format = HFSystemFormat(model_name="HuggingFaceH4/zephyr-7b-beta") + + targets = [ + { + "target": "2", + "inputs": {}, + "source": "<|system|>\nYou are a smart assistant.\nsolve the math exercises\n<|user|>\n1+2\n<|assistant|>\nThe answer is 3\n<|user|>\n4-2\n<|assistant|>\nThe answer is 2\n<|user|>\n1+1\n<|assistant|>\nThe answer is ", + }, + { + "target": "5", + "inputs": {}, + "source": "<|system|>\nYou are a smart assistant.\nsolve the math exercises\n<|user|>\n1+2\n<|assistant|>\nThe answer is 3\n<|user|>\n4-2\n<|assistant|>\nThe answer is 2\n<|user|>\n3+2\n<|assistant|>\nThe answer is ", + }, + ] + + check_operator( + operator=system_format, + inputs=inputs, + targets=targets, + tester=self, + ) + def test_system_format(self): instruction = "solve the math exercises"