From bac4ff17ff1a7cba22b8fdffe1750e273bdaa438 Mon Sep 17 00:00:00 2001 From: Milo Cress Date: Thu, 8 Feb 2024 02:33:53 +0000 Subject: [PATCH 01/19] strengthened chat formatting validation --- llmfoundry/data/finetuning/tasks.py | 43 ++++++++++++++++++++++++++++- 1 file changed, 42 insertions(+), 1 deletion(-) diff --git a/llmfoundry/data/finetuning/tasks.py b/llmfoundry/data/finetuning/tasks.py index 92251df8a2..c619e7b98e 100644 --- a/llmfoundry/data/finetuning/tasks.py +++ b/llmfoundry/data/finetuning/tasks.py @@ -54,6 +54,11 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]: _ALLOWED_RESPONSE_KEYS = {'response', 'completion'} _ALLOWED_PROMPT_KEYS = {'prompt'} +_ALLOWED_MESSAGES_KEYS = {'messages'} +_ALLOWED_ROLE_KEYS = {'role'} +_ALLOWED_CONTENT_KEYS = {'content'} +_ALLOWED_ROLES = {'user', 'assistant', 'system'} +_ALLOWED_LAST_MESSAGE_ROLES = {'assistant'} DOWNLOADED_FT_DATASETS_DIRPATH = os.path.abspath( os.path.join(os.path.realpath(__file__), os.pardir, os.pardir, os.pardir, '.downloaded_finetuning')) @@ -101,6 +106,41 @@ def _is_empty_or_nonexistent(dirpath: str) -> bool: return not os.path.isdir(dirpath) or len(os.listdir(dirpath)) == 0 +def _get_role_key(message: Dict[str, str]) -> str: + role_keys = _ALLOWED_ROLE_KEYS.intersection(message.keys()) + assert len(role_keys) == 1 + role_key = list(role_keys)[0] + return role_key + + +def _get_content_key(message: Dict[str, str]) -> str: + content_keys = _ALLOWED_CONTENT_KEYS.intersection(message.keys()) + assert len(content_keys) == 1 + content_key = list(content_keys)[0] + return content_key + + +def _get_message_key(example: ChatFormattedDict): + assert len(example.keys()) == 1 + message_key = example.keys()[0] + assert message_key in _ALLOWED_MESSAGES_KEYS + return message_key + + +def _validate_chat_formatted_example(example: ChatFormattedDict): + messages = example[example.keys()[0]] + for message in messages: + assert len(message.keys()) == 2 + role_key, _ = _get_role_key(message), _get_content_key(message) + assert message[role_key] in _ALLOWED_ROLES + + assert len(messages) > 1 + last_message = messages[-1] + role_key = _get_role_key(last_message) + last_role = last_message[role_key] + assert last_role in _ALLOWED_LAST_MESSAGE_ROLES + + def _slice_chat_formatted_example( example: ChatFormattedDict, tokenizer: PreTrainedTokenizerBase) -> Tuple[str, str]: @@ -117,7 +157,8 @@ def _slice_chat_formatted_example( ValueError: If the chat example has less than two messages or if the last message is not from the assistant. KeyError: If a message does not have a role or content. """ - messages = example['messages'] + _validate_chat_formatted_example(example) + messages = example[_get_message_key(example)] if len(messages) < 2: raise ValueError( From 2067c6a0e9031879f7c1d7557888e8652337b001 Mon Sep 17 00:00:00 2001 From: Milo Cress Date: Thu, 8 Feb 2024 14:24:41 +0000 Subject: [PATCH 02/19] fix types --- llmfoundry/data/finetuning/tasks.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/llmfoundry/data/finetuning/tasks.py b/llmfoundry/data/finetuning/tasks.py index c619e7b98e..888526d994 100644 --- a/llmfoundry/data/finetuning/tasks.py +++ b/llmfoundry/data/finetuning/tasks.py @@ -122,13 +122,13 @@ def _get_content_key(message: Dict[str, str]) -> str: def _get_message_key(example: ChatFormattedDict): assert len(example.keys()) == 1 - message_key = example.keys()[0] + message_key = list(example.keys())[0] assert message_key in _ALLOWED_MESSAGES_KEYS return message_key def _validate_chat_formatted_example(example: ChatFormattedDict): - messages = example[example.keys()[0]] + messages = example[_get_message_key(example)] for message in messages: assert len(message.keys()) == 2 role_key, _ = _get_role_key(message), _get_content_key(message) From 0c68280f004d5f32975dcb1515974c1d56e32a3b Mon Sep 17 00:00:00 2001 From: Milo Cress Date: Thu, 8 Feb 2024 14:44:06 +0000 Subject: [PATCH 03/19] made assert messages more descriptive --- llmfoundry/data/finetuning/tasks.py | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/llmfoundry/data/finetuning/tasks.py b/llmfoundry/data/finetuning/tasks.py index 888526d994..5f0bf832be 100644 --- a/llmfoundry/data/finetuning/tasks.py +++ b/llmfoundry/data/finetuning/tasks.py @@ -108,37 +108,42 @@ def _is_empty_or_nonexistent(dirpath: str) -> bool: def _get_role_key(message: Dict[str, str]) -> str: role_keys = _ALLOWED_ROLE_KEYS.intersection(message.keys()) - assert len(role_keys) == 1 + assert len( + role_keys) == 1, f'Expected 1 role key, but found {len(role_keys)}' role_key = list(role_keys)[0] return role_key def _get_content_key(message: Dict[str, str]) -> str: content_keys = _ALLOWED_CONTENT_KEYS.intersection(message.keys()) - assert len(content_keys) == 1 + assert len(content_keys + ) == 1, f'Expected 1 content key, but found {len(content_keys)}' content_key = list(content_keys)[0] return content_key def _get_message_key(example: ChatFormattedDict): - assert len(example.keys()) == 1 + assert len(example.keys() + ) == 1, f'Expected 1 message key, but found {len(example.keys())}' message_key = list(example.keys())[0] - assert message_key in _ALLOWED_MESSAGES_KEYS + assert message_key in _ALLOWED_MESSAGES_KEYS, f'Invalid message key: {message_key}' return message_key def _validate_chat_formatted_example(example: ChatFormattedDict): messages = example[_get_message_key(example)] for message in messages: - assert len(message.keys()) == 2 + assert len(message.keys( + )) == 2, f'Expected 2 keys in message, but found {len(message.keys())}' role_key, _ = _get_role_key(message), _get_content_key(message) - assert message[role_key] in _ALLOWED_ROLES + assert message[ + role_key] in _ALLOWED_ROLES, f'Invalid role: {message[role_key]}' - assert len(messages) > 1 + assert len(messages) > 1, 'Chat example must have at least two messages' last_message = messages[-1] role_key = _get_role_key(last_message) last_role = last_message[role_key] - assert last_role in _ALLOWED_LAST_MESSAGE_ROLES + assert last_role in _ALLOWED_LAST_MESSAGE_ROLES, f'Invalid last message role: {last_role}' def _slice_chat_formatted_example( From b2453b84f4620630b36751eafaf8307deda2a511 Mon Sep 17 00:00:00 2001 From: Milo Cress Date: Thu, 8 Feb 2024 19:33:41 +0000 Subject: [PATCH 04/19] used raise instead of assert, added type checks --- llmfoundry/data/finetuning/tasks.py | 41 ++++++++++++++++++++--------- 1 file changed, 28 insertions(+), 13 deletions(-) diff --git a/llmfoundry/data/finetuning/tasks.py b/llmfoundry/data/finetuning/tasks.py index fd03598ddf..e75355398a 100644 --- a/llmfoundry/data/finetuning/tasks.py +++ b/llmfoundry/data/finetuning/tasks.py @@ -108,43 +108,58 @@ def _is_empty_or_nonexistent(dirpath: str) -> bool: def _get_role_key(message: Dict[str, str]) -> str: + if not isinstance(message, dict): + raise TypeError( + f'Expected message to be a dict, but found {type(message)}') role_keys = _ALLOWED_ROLE_KEYS.intersection(message.keys()) - assert len( - role_keys) == 1, f'Expected 1 role key, but found {len(role_keys)}' + if len(role_keys) != 1: + raise ValueError(f'Expected 1 role key, but found {len(role_keys)}') role_key = list(role_keys)[0] return role_key def _get_content_key(message: Dict[str, str]) -> str: + if not isinstance(message, dict): + raise TypeError( + f'Expected message to be a dict, but found {type(message)}') content_keys = _ALLOWED_CONTENT_KEYS.intersection(message.keys()) - assert len(content_keys - ) == 1, f'Expected 1 content key, but found {len(content_keys)}' + if len(content_keys) != 1: + raise ValueError( + f'Expected 1 content key, but found {len(content_keys)}') content_key = list(content_keys)[0] return content_key def _get_message_key(example: ChatFormattedDict): - assert len(example.keys() - ) == 1, f'Expected 1 message key, but found {len(example.keys())}' + if len(example.keys()) != 1: + raise ValueError( + f'Expected 1 message key, but found {len(example.keys())}') message_key = list(example.keys())[0] - assert message_key in _ALLOWED_MESSAGES_KEYS, f'Invalid message key: {message_key}' + if message_key not in _ALLOWED_MESSAGES_KEYS: + raise ValueError(f'Invalid message key: {message_key}') return message_key def _validate_chat_formatted_example(example: ChatFormattedDict): + if not isinstance(example, dict): + raise TypeError( + f'Expected example to be a dict, but found {type(example)}') messages = example[_get_message_key(example)] for message in messages: - assert len(message.keys( - )) == 2, f'Expected 2 keys in message, but found {len(message.keys())}' + if len(message.keys()) != 2: + raise ValueError( + f'Expected 2 keys in message, but found {len(message.keys())}') role_key, _ = _get_role_key(message), _get_content_key(message) - assert message[ - role_key] in _ALLOWED_ROLES, f'Invalid role: {message[role_key]}' + if message[role_key] not in _ALLOWED_ROLES: + raise ValueError(f'Invalid role: {message[role_key]}') - assert len(messages) > 1, 'Chat example must have at least two messages' + if len(messages) <= 1: + raise ValueError('Chat example must have at least two messages') last_message = messages[-1] role_key = _get_role_key(last_message) last_role = last_message[role_key] - assert last_role in _ALLOWED_LAST_MESSAGE_ROLES, f'Invalid last message role: {last_role}' + if last_role not in _ALLOWED_LAST_MESSAGE_ROLES: + raise ValueError(f'Invalid last message role: {last_role}') def _slice_chat_formatted_example( From 51be88fd12070b97f1a84f8ae5c47e3b101ea9cc Mon Sep 17 00:00:00 2001 From: Milo Cress Date: Thu, 8 Feb 2024 19:38:15 +0000 Subject: [PATCH 05/19] added list type check --- llmfoundry/data/finetuning/tasks.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/llmfoundry/data/finetuning/tasks.py b/llmfoundry/data/finetuning/tasks.py index e75355398a..5d548a41a4 100644 --- a/llmfoundry/data/finetuning/tasks.py +++ b/llmfoundry/data/finetuning/tasks.py @@ -145,6 +145,9 @@ def _validate_chat_formatted_example(example: ChatFormattedDict): raise TypeError( f'Expected example to be a dict, but found {type(example)}') messages = example[_get_message_key(example)] + if not isinstance(messages, list): + raise TypeError( + f'Expected messages to be a list, but found {type(messages)}') for message in messages: if len(message.keys()) != 2: raise ValueError( From 69798033cdd07c1c03b0a10f55076deafb0776f9 Mon Sep 17 00:00:00 2001 From: Milo Cress Date: Thu, 8 Feb 2024 19:40:08 +0000 Subject: [PATCH 06/19] type error if no string content --- llmfoundry/data/finetuning/tasks.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/llmfoundry/data/finetuning/tasks.py b/llmfoundry/data/finetuning/tasks.py index 5d548a41a4..44280f7aae 100644 --- a/llmfoundry/data/finetuning/tasks.py +++ b/llmfoundry/data/finetuning/tasks.py @@ -152,9 +152,14 @@ def _validate_chat_formatted_example(example: ChatFormattedDict): if len(message.keys()) != 2: raise ValueError( f'Expected 2 keys in message, but found {len(message.keys())}') - role_key, _ = _get_role_key(message), _get_content_key(message) + role_key, content_key = _get_role_key(message), _get_content_key( + message) if message[role_key] not in _ALLOWED_ROLES: raise ValueError(f'Invalid role: {message[role_key]}') + if not isinstance(message[content_key], str): + raise TypeError( + f'Expected content to be a string, but found {type(message[content_key])}' + ) if len(messages) <= 1: raise ValueError('Chat example must have at least two messages') From 5d219699673bb07d501545186a76b7aa71171371 Mon Sep 17 00:00:00 2001 From: Milo Cress Date: Thu, 8 Feb 2024 19:43:26 +0000 Subject: [PATCH 07/19] add test case for new validation --- tests/data/test_template_tokenization.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/data/test_template_tokenization.py b/tests/data/test_template_tokenization.py index 829d1ebbc0..5491b94521 100644 --- a/tests/data/test_template_tokenization.py +++ b/tests/data/test_template_tokenization.py @@ -42,8 +42,10 @@ def test_tokenize_chat_example_malformed(): 'content': 'user message not followed by an assistant label' }] } + wrong_type = {'messages': 'this is not a list of messages'} malformed_chat_examples = [ - too_few_messages, no_content, ends_with_user_role, no_assistant_message + too_few_messages, no_content, ends_with_user_role, no_assistant_message, + wrong_type ] my_tokenizer = build_tokenizer('mosaicml/mpt-7b-8k-chat', {}) for example in malformed_chat_examples: From 6e6a206ea3d3965f1d13fda06a011f40cfa0651c Mon Sep 17 00:00:00 2001 From: Milo Cress Date: Thu, 8 Feb 2024 20:23:01 +0000 Subject: [PATCH 08/19] relaxed type constraints to interface minimum --- llmfoundry/data/finetuning/tasks.py | 29 ++++++++++++++++++----------- 1 file changed, 18 insertions(+), 11 deletions(-) diff --git a/llmfoundry/data/finetuning/tasks.py b/llmfoundry/data/finetuning/tasks.py index 44280f7aae..1c9393d62e 100644 --- a/llmfoundry/data/finetuning/tasks.py +++ b/llmfoundry/data/finetuning/tasks.py @@ -84,7 +84,11 @@ def _get_example_type(example: Example) -> ExampleType: Raises: KeyError: If the example type is unknown. """ - if 'messages' in example: + if not hasattr(example, 'keys'): + raise TypeError( + f'Expected example to have dict-like, but found {type(example)}') + if any(allowed_message_key in example + for allowed_message_key in _ALLOWED_MESSAGES_KEYS): return 'chat' elif any([ pr in example @@ -108,9 +112,9 @@ def _is_empty_or_nonexistent(dirpath: str) -> bool: def _get_role_key(message: Dict[str, str]) -> str: - if not isinstance(message, dict): + if not hasattr(message, '__getitem__') or not hasattr(message, 'keys'): raise TypeError( - f'Expected message to be a dict, but found {type(message)}') + f'Expected message to be dict-like, but found {type(message)}') role_keys = _ALLOWED_ROLE_KEYS.intersection(message.keys()) if len(role_keys) != 1: raise ValueError(f'Expected 1 role key, but found {len(role_keys)}') @@ -119,9 +123,9 @@ def _get_role_key(message: Dict[str, str]) -> str: def _get_content_key(message: Dict[str, str]) -> str: - if not isinstance(message, dict): + if not hasattr(message, '__getitem__') or not hasattr(message, 'keys'): raise TypeError( - f'Expected message to be a dict, but found {type(message)}') + f'Expected message to be dict-like, but found {type(message)}') content_keys = _ALLOWED_CONTENT_KEYS.intersection(message.keys()) if len(content_keys) != 1: raise ValueError( @@ -131,6 +135,9 @@ def _get_content_key(message: Dict[str, str]) -> str: def _get_message_key(example: ChatFormattedDict): + if not hasattr(example, 'keys'): + raise TypeError( + f'Expected example to have keys(), but found {type(example)}') if len(example.keys()) != 1: raise ValueError( f'Expected 1 message key, but found {len(example.keys())}') @@ -141,19 +148,19 @@ def _get_message_key(example: ChatFormattedDict): def _validate_chat_formatted_example(example: ChatFormattedDict): - if not isinstance(example, dict): + if not hasattr(example, '__getitem__') or not hasattr(example, 'keys'): raise TypeError( - f'Expected example to be a dict, but found {type(example)}') + f'Expected example to be dict-like, but found {type(example)}') messages = example[_get_message_key(example)] - if not isinstance(messages, list): + if not hasattr(messages, '__iter__'): raise TypeError( - f'Expected messages to be a list, but found {type(messages)}') + f'Expected messages to be an iterator, but found {type(messages)}') for message in messages: + role_key, content_key = _get_role_key(message), _get_content_key( + message) if len(message.keys()) != 2: raise ValueError( f'Expected 2 keys in message, but found {len(message.keys())}') - role_key, content_key = _get_role_key(message), _get_content_key( - message) if message[role_key] not in _ALLOWED_ROLES: raise ValueError(f'Invalid role: {message[role_key]}') if not isinstance(message[content_key], str): From 85710ff1d9b4deaa7b0cb3db092892906be1d61f Mon Sep 17 00:00:00 2001 From: Milo Cress Date: Thu, 8 Feb 2024 20:50:41 +0000 Subject: [PATCH 09/19] use Mapping and Iterable --- llmfoundry/data/finetuning/tasks.py | 32 +++++++++++++++-------------- 1 file changed, 17 insertions(+), 15 deletions(-) diff --git a/llmfoundry/data/finetuning/tasks.py b/llmfoundry/data/finetuning/tasks.py index 1c9393d62e..e22702786c 100644 --- a/llmfoundry/data/finetuning/tasks.py +++ b/llmfoundry/data/finetuning/tasks.py @@ -35,6 +35,7 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]: import logging import os import warnings +from collections.abc import Iterable, Mapping from functools import partial from pathlib import Path from typing import (Any, Callable, Dict, List, Literal, Optional, Tuple, Union, @@ -84,7 +85,7 @@ def _get_example_type(example: Example) -> ExampleType: Raises: KeyError: If the example type is unknown. """ - if not hasattr(example, 'keys'): + if not isinstance(example, Mapping): raise TypeError( f'Expected example to have dict-like, but found {type(example)}') if any(allowed_message_key in example @@ -111,10 +112,10 @@ def _is_empty_or_nonexistent(dirpath: str) -> bool: return not os.path.isdir(dirpath) or len(os.listdir(dirpath)) == 0 -def _get_role_key(message: Dict[str, str]) -> str: - if not hasattr(message, '__getitem__') or not hasattr(message, 'keys'): +def _get_role_key(message: Mapping[str, str]) -> str: + if not isinstance(message, Mapping): raise TypeError( - f'Expected message to be dict-like, but found {type(message)}') + f'Expected message to be a mapping, but found {type(message)}') role_keys = _ALLOWED_ROLE_KEYS.intersection(message.keys()) if len(role_keys) != 1: raise ValueError(f'Expected 1 role key, but found {len(role_keys)}') @@ -122,10 +123,10 @@ def _get_role_key(message: Dict[str, str]) -> str: return role_key -def _get_content_key(message: Dict[str, str]) -> str: - if not hasattr(message, '__getitem__') or not hasattr(message, 'keys'): +def _get_content_key(message: Mapping[str, str]) -> str: + if not isinstance(message, Mapping): raise TypeError( - f'Expected message to be dict-like, but found {type(message)}') + f'Expected message to be a mapping, but found {type(message)}') content_keys = _ALLOWED_CONTENT_KEYS.intersection(message.keys()) if len(content_keys) != 1: raise ValueError( @@ -134,10 +135,10 @@ def _get_content_key(message: Dict[str, str]) -> str: return content_key -def _get_message_key(example: ChatFormattedDict): - if not hasattr(example, 'keys'): +def _get_message_key(example: Mapping[str, List[Mapping[str, str]]]): + if not isinstance(example, Mapping): raise TypeError( - f'Expected example to have keys(), but found {type(example)}') + f'Expected example to be a mapping, but found {type(example)}') if len(example.keys()) != 1: raise ValueError( f'Expected 1 message key, but found {len(example.keys())}') @@ -147,14 +148,15 @@ def _get_message_key(example: ChatFormattedDict): return message_key -def _validate_chat_formatted_example(example: ChatFormattedDict): - if not hasattr(example, '__getitem__') or not hasattr(example, 'keys'): +def _validate_chat_formatted_example(example: Mapping[str, List[Mapping[str, + str]]]): + if not isinstance(example, Mapping): raise TypeError( - f'Expected example to be dict-like, but found {type(example)}') + f'Expected example to be a mapping, but found {type(example)}') messages = example[_get_message_key(example)] - if not hasattr(messages, '__iter__'): + if not isinstance(messages, Iterable): raise TypeError( - f'Expected messages to be an iterator, but found {type(messages)}') + f'Expected messages to be an iterable, but found {type(messages)}') for message in messages: role_key, content_key = _get_role_key(message), _get_content_key( message) From aa975a9186e259f2714e08793bd9a4fdba107f8f Mon Sep 17 00:00:00 2001 From: Milo Cress Date: Thu, 8 Feb 2024 21:06:46 +0000 Subject: [PATCH 10/19] fix mapping in type aliases too --- llmfoundry/data/finetuning/tasks.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/llmfoundry/data/finetuning/tasks.py b/llmfoundry/data/finetuning/tasks.py index e22702786c..99680bb354 100644 --- a/llmfoundry/data/finetuning/tasks.py +++ b/llmfoundry/data/finetuning/tasks.py @@ -66,8 +66,8 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]: '.downloaded_finetuning')) SUPPORTED_EXTENSIONS = ['.csv', '.jsonl', '.parquet'] -PromptResponseDict = Dict[str, str] -ChatFormattedDict = Dict[str, List[Dict[str, str]]] +PromptResponseDict = Mapping[str, str] +ChatFormattedDict = Mapping[str, Iterable[Mapping[str, str]]] Example = Union[PromptResponseDict, ChatFormattedDict] ExampleType = Literal['prompt_response', 'chat'] TokenizedExample = Dict[str, List[int]] From 24415cf20febc4beba7d32f4181e93d6d16d9a28 Mon Sep 17 00:00:00 2001 From: Milo Cress Date: Thu, 8 Feb 2024 21:17:29 +0000 Subject: [PATCH 11/19] iterable -> sequence --- llmfoundry/data/finetuning/tasks.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/llmfoundry/data/finetuning/tasks.py b/llmfoundry/data/finetuning/tasks.py index 99680bb354..ac658ca1a7 100644 --- a/llmfoundry/data/finetuning/tasks.py +++ b/llmfoundry/data/finetuning/tasks.py @@ -35,7 +35,7 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]: import logging import os import warnings -from collections.abc import Iterable, Mapping +from collections.abc import Mapping, Sequence from functools import partial from pathlib import Path from typing import (Any, Callable, Dict, List, Literal, Optional, Tuple, Union, @@ -67,7 +67,7 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]: SUPPORTED_EXTENSIONS = ['.csv', '.jsonl', '.parquet'] PromptResponseDict = Mapping[str, str] -ChatFormattedDict = Mapping[str, Iterable[Mapping[str, str]]] +ChatFormattedDict = Mapping[str, Sequence[Mapping[str, str]]] Example = Union[PromptResponseDict, ChatFormattedDict] ExampleType = Literal['prompt_response', 'chat'] TokenizedExample = Dict[str, List[int]] @@ -154,7 +154,7 @@ def _validate_chat_formatted_example(example: Mapping[str, List[Mapping[str, raise TypeError( f'Expected example to be a mapping, but found {type(example)}') messages = example[_get_message_key(example)] - if not isinstance(messages, Iterable): + if not isinstance(messages, Sequence): raise TypeError( f'Expected messages to be an iterable, but found {type(messages)}') for message in messages: From eaa7e8a49668de22dcbc55c801405055cbcbb00a Mon Sep 17 00:00:00 2001 From: Milo Cress Date: Thu, 8 Feb 2024 21:59:27 +0000 Subject: [PATCH 12/19] sequence -> list --- llmfoundry/data/finetuning/tasks.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/llmfoundry/data/finetuning/tasks.py b/llmfoundry/data/finetuning/tasks.py index ac658ca1a7..9e9cf45fc2 100644 --- a/llmfoundry/data/finetuning/tasks.py +++ b/llmfoundry/data/finetuning/tasks.py @@ -35,7 +35,7 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]: import logging import os import warnings -from collections.abc import Mapping, Sequence +from collections.abc import Mapping from functools import partial from pathlib import Path from typing import (Any, Callable, Dict, List, Literal, Optional, Tuple, Union, @@ -67,7 +67,7 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]: SUPPORTED_EXTENSIONS = ['.csv', '.jsonl', '.parquet'] PromptResponseDict = Mapping[str, str] -ChatFormattedDict = Mapping[str, Sequence[Mapping[str, str]]] +ChatFormattedDict = Mapping[str, List[Mapping[str, str]]] Example = Union[PromptResponseDict, ChatFormattedDict] ExampleType = Literal['prompt_response', 'chat'] TokenizedExample = Dict[str, List[int]] @@ -154,7 +154,7 @@ def _validate_chat_formatted_example(example: Mapping[str, List[Mapping[str, raise TypeError( f'Expected example to be a mapping, but found {type(example)}') messages = example[_get_message_key(example)] - if not isinstance(messages, Sequence): + if not isinstance(messages, List): raise TypeError( f'Expected messages to be an iterable, but found {type(messages)}') for message in messages: From b179a9e8faf9f40b177e973ea4ac694729108219 Mon Sep 17 00:00:00 2001 From: Milo Cress Date: Thu, 8 Feb 2024 22:24:35 +0000 Subject: [PATCH 13/19] Mapping -> Dict --- llmfoundry/data/finetuning/tasks.py | 24 +++++++++++------------- 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/llmfoundry/data/finetuning/tasks.py b/llmfoundry/data/finetuning/tasks.py index 9e9cf45fc2..55002630a6 100644 --- a/llmfoundry/data/finetuning/tasks.py +++ b/llmfoundry/data/finetuning/tasks.py @@ -35,7 +35,6 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]: import logging import os import warnings -from collections.abc import Mapping from functools import partial from pathlib import Path from typing import (Any, Callable, Dict, List, Literal, Optional, Tuple, Union, @@ -66,8 +65,8 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]: '.downloaded_finetuning')) SUPPORTED_EXTENSIONS = ['.csv', '.jsonl', '.parquet'] -PromptResponseDict = Mapping[str, str] -ChatFormattedDict = Mapping[str, List[Mapping[str, str]]] +PromptResponseDict = Dict[str, str] +ChatFormattedDict = Dict[str, List[Dict[str, str]]] Example = Union[PromptResponseDict, ChatFormattedDict] ExampleType = Literal['prompt_response', 'chat'] TokenizedExample = Dict[str, List[int]] @@ -85,7 +84,7 @@ def _get_example_type(example: Example) -> ExampleType: Raises: KeyError: If the example type is unknown. """ - if not isinstance(example, Mapping): + if not isinstance(example, Dict): raise TypeError( f'Expected example to have dict-like, but found {type(example)}') if any(allowed_message_key in example @@ -112,8 +111,8 @@ def _is_empty_or_nonexistent(dirpath: str) -> bool: return not os.path.isdir(dirpath) or len(os.listdir(dirpath)) == 0 -def _get_role_key(message: Mapping[str, str]) -> str: - if not isinstance(message, Mapping): +def _get_role_key(message: Dict[str, str]) -> str: + if not isinstance(message, Dict): raise TypeError( f'Expected message to be a mapping, but found {type(message)}') role_keys = _ALLOWED_ROLE_KEYS.intersection(message.keys()) @@ -123,8 +122,8 @@ def _get_role_key(message: Mapping[str, str]) -> str: return role_key -def _get_content_key(message: Mapping[str, str]) -> str: - if not isinstance(message, Mapping): +def _get_content_key(message: Dict[str, str]) -> str: + if not isinstance(message, Dict): raise TypeError( f'Expected message to be a mapping, but found {type(message)}') content_keys = _ALLOWED_CONTENT_KEYS.intersection(message.keys()) @@ -135,8 +134,8 @@ def _get_content_key(message: Mapping[str, str]) -> str: return content_key -def _get_message_key(example: Mapping[str, List[Mapping[str, str]]]): - if not isinstance(example, Mapping): +def _get_message_key(example: Dict[str, List[Dict[str, str]]]): + if not isinstance(example, Dict): raise TypeError( f'Expected example to be a mapping, but found {type(example)}') if len(example.keys()) != 1: @@ -148,9 +147,8 @@ def _get_message_key(example: Mapping[str, List[Mapping[str, str]]]): return message_key -def _validate_chat_formatted_example(example: Mapping[str, List[Mapping[str, - str]]]): - if not isinstance(example, Mapping): +def _validate_chat_formatted_example(example: ChatFormattedDict): + if not isinstance(example, Dict): raise TypeError( f'Expected example to be a mapping, but found {type(example)}') messages = example[_get_message_key(example)] From ed91ebdd919b4fb30cc039d7e212494523a9aa14 Mon Sep 17 00:00:00 2001 From: Milo Cress Date: Thu, 8 Feb 2024 22:29:31 +0000 Subject: [PATCH 14/19] use mapping again --- llmfoundry/data/finetuning/tasks.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/llmfoundry/data/finetuning/tasks.py b/llmfoundry/data/finetuning/tasks.py index 55002630a6..235b3f016f 100644 --- a/llmfoundry/data/finetuning/tasks.py +++ b/llmfoundry/data/finetuning/tasks.py @@ -35,6 +35,7 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]: import logging import os import warnings +from collections.abc import Mapping from functools import partial from pathlib import Path from typing import (Any, Callable, Dict, List, Literal, Optional, Tuple, Union, @@ -112,7 +113,7 @@ def _is_empty_or_nonexistent(dirpath: str) -> bool: def _get_role_key(message: Dict[str, str]) -> str: - if not isinstance(message, Dict): + if not isinstance(message, Mapping): raise TypeError( f'Expected message to be a mapping, but found {type(message)}') role_keys = _ALLOWED_ROLE_KEYS.intersection(message.keys()) @@ -123,7 +124,7 @@ def _get_role_key(message: Dict[str, str]) -> str: def _get_content_key(message: Dict[str, str]) -> str: - if not isinstance(message, Dict): + if not isinstance(message, Mapping): raise TypeError( f'Expected message to be a mapping, but found {type(message)}') content_keys = _ALLOWED_CONTENT_KEYS.intersection(message.keys()) @@ -134,8 +135,8 @@ def _get_content_key(message: Dict[str, str]) -> str: return content_key -def _get_message_key(example: Dict[str, List[Dict[str, str]]]): - if not isinstance(example, Dict): +def _get_message_key(example: ChatFormattedDict): + if not isinstance(example, Mapping): raise TypeError( f'Expected example to be a mapping, but found {type(example)}') if len(example.keys()) != 1: @@ -148,7 +149,7 @@ def _get_message_key(example: Dict[str, List[Dict[str, str]]]): def _validate_chat_formatted_example(example: ChatFormattedDict): - if not isinstance(example, Dict): + if not isinstance(example, Mapping): raise TypeError( f'Expected example to be a mapping, but found {type(example)}') messages = example[_get_message_key(example)] From 035597432540587f86866b776371cb1ff854530b Mon Sep 17 00:00:00 2001 From: Milo Cress Date: Thu, 8 Feb 2024 22:31:23 +0000 Subject: [PATCH 15/19] fixed another one --- llmfoundry/data/finetuning/tasks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llmfoundry/data/finetuning/tasks.py b/llmfoundry/data/finetuning/tasks.py index 235b3f016f..e8040bee46 100644 --- a/llmfoundry/data/finetuning/tasks.py +++ b/llmfoundry/data/finetuning/tasks.py @@ -85,7 +85,7 @@ def _get_example_type(example: Example) -> ExampleType: Raises: KeyError: If the example type is unknown. """ - if not isinstance(example, Dict): + if not isinstance(example, Mapping): raise TypeError( f'Expected example to have dict-like, but found {type(example)}') if any(allowed_message_key in example From d7211520e16dade8ef56a0609f208100cf6c23cf Mon Sep 17 00:00:00 2001 From: Milo Cress Date: Fri, 9 Feb 2024 16:09:04 +0000 Subject: [PATCH 16/19] updated message --- llmfoundry/data/finetuning/tasks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llmfoundry/data/finetuning/tasks.py b/llmfoundry/data/finetuning/tasks.py index e8040bee46..a11a238521 100644 --- a/llmfoundry/data/finetuning/tasks.py +++ b/llmfoundry/data/finetuning/tasks.py @@ -87,7 +87,7 @@ def _get_example_type(example: Example) -> ExampleType: """ if not isinstance(example, Mapping): raise TypeError( - f'Expected example to have dict-like, but found {type(example)}') + f'Expected example to be a Mapping, but found {type(example)}') if any(allowed_message_key in example for allowed_message_key in _ALLOWED_MESSAGES_KEYS): return 'chat' From 3201258aee19ea0a192c42d08d102e552e055c36 Mon Sep 17 00:00:00 2001 From: Milo Cress Date: Fri, 9 Feb 2024 16:28:06 +0000 Subject: [PATCH 17/19] factored out duplicate functions --- llmfoundry/data/finetuning/tasks.py | 75 +++++++++-------------------- 1 file changed, 23 insertions(+), 52 deletions(-) diff --git a/llmfoundry/data/finetuning/tasks.py b/llmfoundry/data/finetuning/tasks.py index a11a238521..adf0f3487d 100644 --- a/llmfoundry/data/finetuning/tasks.py +++ b/llmfoundry/data/finetuning/tasks.py @@ -38,8 +38,8 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]: from collections.abc import Mapping from functools import partial from pathlib import Path -from typing import (Any, Callable, Dict, List, Literal, Optional, Tuple, Union, - cast) +from typing import (Any, Callable, Dict, List, Literal, Optional, Set, Tuple, + Union, cast) import datasets as hf_datasets import huggingface_hub as hf_hub @@ -112,53 +112,38 @@ def _is_empty_or_nonexistent(dirpath: str) -> bool: return not os.path.isdir(dirpath) or len(os.listdir(dirpath)) == 0 -def _get_role_key(message: Dict[str, str]) -> str: - if not isinstance(message, Mapping): +def _get_key(dictionary: Dict[str, Any], allowed_keys: Set[str]): + if not isinstance(dictionary, Mapping): raise TypeError( - f'Expected message to be a mapping, but found {type(message)}') - role_keys = _ALLOWED_ROLE_KEYS.intersection(message.keys()) - if len(role_keys) != 1: - raise ValueError(f'Expected 1 role key, but found {len(role_keys)}') - role_key = list(role_keys)[0] - return role_key - - -def _get_content_key(message: Dict[str, str]) -> str: - if not isinstance(message, Mapping): - raise TypeError( - f'Expected message to be a mapping, but found {type(message)}') - content_keys = _ALLOWED_CONTENT_KEYS.intersection(message.keys()) - if len(content_keys) != 1: - raise ValueError( - f'Expected 1 content key, but found {len(content_keys)}') - content_key = list(content_keys)[0] - return content_key - - -def _get_message_key(example: ChatFormattedDict): - if not isinstance(example, Mapping): - raise TypeError( - f'Expected example to be a mapping, but found {type(example)}') - if len(example.keys()) != 1: + f'Expected dictionary to be a mapping, but found {type(dictionary)}' + ) + desired_keys = allowed_keys.intersection(dictionary.keys()) + if len(desired_keys) != 1: raise ValueError( - f'Expected 1 message key, but found {len(example.keys())}') - message_key = list(example.keys())[0] - if message_key not in _ALLOWED_MESSAGES_KEYS: - raise ValueError(f'Invalid message key: {message_key}') - return message_key + f'Dictionary has multiple keys in `allowed_keys`: {desired_keys}') + return list(desired_keys)[0] def _validate_chat_formatted_example(example: ChatFormattedDict): if not isinstance(example, Mapping): raise TypeError( f'Expected example to be a mapping, but found {type(example)}') - messages = example[_get_message_key(example)] + messages = example[_get_key(example, _ALLOWED_MESSAGES_KEYS)] if not isinstance(messages, List): raise TypeError( f'Expected messages to be an iterable, but found {type(messages)}') + if len(messages) <= 1: + raise ValueError('Chat example must have at least two messages') + + last_message = messages[-1] + role_key = _get_key(last_message, _ALLOWED_ROLE_KEYS) + last_role = last_message[role_key] + if last_role not in _ALLOWED_LAST_MESSAGE_ROLES: + raise ValueError(f'Invalid last message role: {last_role}') + for message in messages: - role_key, content_key = _get_role_key(message), _get_content_key( - message) + role_key, content_key = _get_key(message, _ALLOWED_ROLE_KEYS), _get_key( + message, _ALLOWED_CONTENT_KEYS) if len(message.keys()) != 2: raise ValueError( f'Expected 2 keys in message, but found {len(message.keys())}') @@ -169,14 +154,6 @@ def _validate_chat_formatted_example(example: ChatFormattedDict): f'Expected content to be a string, but found {type(message[content_key])}' ) - if len(messages) <= 1: - raise ValueError('Chat example must have at least two messages') - last_message = messages[-1] - role_key = _get_role_key(last_message) - last_role = last_message[role_key] - if last_role not in _ALLOWED_LAST_MESSAGE_ROLES: - raise ValueError(f'Invalid last message role: {last_role}') - def _slice_chat_formatted_example( example: ChatFormattedDict, @@ -195,18 +172,12 @@ def _slice_chat_formatted_example( KeyError: If a message does not have a role or content. """ _validate_chat_formatted_example(example) - messages = example[_get_message_key(example)] + messages = example[_get_key(example, _ALLOWED_MESSAGES_KEYS)] - if len(messages) < 2: - raise ValueError( - f'chat example must have at least two messages. {messages=}') last_message = messages[-1] if last_message['role'] != 'assistant': raise ValueError( f'last message must be from assistant. {last_message=}') - for message in messages: - if 'role' not in message or 'content' not in message: - raise KeyError(f'message must have role and content. {message=}') full_conversation = tokenizer.apply_chat_template(messages, tokenize=False) prompt = tokenizer.apply_chat_template(messages[:-1], From e84eabb73ea6bec66e2dd62b8b6cd6302f0d4d79 Mon Sep 17 00:00:00 2001 From: Milo Cress Date: Fri, 9 Feb 2024 18:29:45 +0000 Subject: [PATCH 18/19] dict -> mapping --- llmfoundry/data/finetuning/tasks.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/llmfoundry/data/finetuning/tasks.py b/llmfoundry/data/finetuning/tasks.py index adf0f3487d..3f9c2f99fe 100644 --- a/llmfoundry/data/finetuning/tasks.py +++ b/llmfoundry/data/finetuning/tasks.py @@ -66,8 +66,8 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]: '.downloaded_finetuning')) SUPPORTED_EXTENSIONS = ['.csv', '.jsonl', '.parquet'] -PromptResponseDict = Dict[str, str] -ChatFormattedDict = Dict[str, List[Dict[str, str]]] +PromptResponseDict = Mapping[str, str] +ChatFormattedDict = Mapping[str, List[Dict[str, str]]] Example = Union[PromptResponseDict, ChatFormattedDict] ExampleType = Literal['prompt_response', 'chat'] TokenizedExample = Dict[str, List[int]] @@ -112,7 +112,7 @@ def _is_empty_or_nonexistent(dirpath: str) -> bool: return not os.path.isdir(dirpath) or len(os.listdir(dirpath)) == 0 -def _get_key(dictionary: Dict[str, Any], allowed_keys: Set[str]): +def _get_key(dictionary: Mapping[str, Any], allowed_keys: Set[str]): if not isinstance(dictionary, Mapping): raise TypeError( f'Expected dictionary to be a mapping, but found {type(dictionary)}' From 2dfa558d32cd63efa6456dbbfd320a0d0053a5d5 Mon Sep 17 00:00:00 2001 From: Milo Cress Date: Fri, 9 Feb 2024 23:42:14 +0000 Subject: [PATCH 19/19] add sequence --- llmfoundry/data/finetuning/tasks.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/llmfoundry/data/finetuning/tasks.py b/llmfoundry/data/finetuning/tasks.py index 99f5744f0e..126ed43812 100644 --- a/llmfoundry/data/finetuning/tasks.py +++ b/llmfoundry/data/finetuning/tasks.py @@ -38,8 +38,8 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]: from collections.abc import Mapping from functools import partial from pathlib import Path -from typing import (Any, Callable, Dict, List, Literal, Optional, Set, Tuple, - Union, cast) +from typing import (Any, Callable, Dict, List, Literal, Optional, Sequence, Set, + Tuple, Union, cast) import datasets as hf_datasets import huggingface_hub as hf_hub