strengthened chat formatting validation #960

Merged
70 changes: 59 additions & 11 deletions llmfoundry/data/finetuning/tasks.py
@@ -35,9 +35,10 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]:
 import logging
 import os
 import warnings
+from collections.abc import Mapping
 from functools import partial
 from pathlib import Path
-from typing import (Any, Callable, Dict, List, Literal, Optional, Sequence,
+from typing import (Any, Callable, Dict, List, Literal, Optional, Sequence, Set,
                     Tuple, Union, cast)
 
 import datasets as hf_datasets
@@ -55,13 +56,18 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]:
 
 _ALLOWED_RESPONSE_KEYS = {'response', 'completion'}
 _ALLOWED_PROMPT_KEYS = {'prompt'}
+_ALLOWED_MESSAGES_KEYS = {'messages'}
+_ALLOWED_ROLE_KEYS = {'role'}
+_ALLOWED_CONTENT_KEYS = {'content'}
+_ALLOWED_ROLES = {'user', 'assistant', 'system'}
+_ALLOWED_LAST_MESSAGE_ROLES = {'assistant'}
 DOWNLOADED_FT_DATASETS_DIRPATH = os.path.abspath(
     os.path.join(os.path.realpath(__file__), os.pardir, os.pardir, os.pardir,
                  '.downloaded_finetuning'))
 SUPPORTED_EXTENSIONS = ['.csv', '.jsonl', '.parquet']
 
-PromptResponseDict = Dict[str, str]
-ChatFormattedDict = Dict[str, List[Dict[str, str]]]
+PromptResponseDict = Mapping[str, str]
+ChatFormattedDict = Mapping[str, List[Dict[str, str]]]
 Example = Union[PromptResponseDict, ChatFormattedDict]
 ExampleType = Literal['prompt_response', 'chat']
 TokenizedExample = Dict[str, List[int]]
@@ -79,7 +85,11 @@ def _get_example_type(example: Example) -> ExampleType:
     Raises:
         KeyError: If the example type is unknown.
     """
-    if 'messages' in example:
+    if not isinstance(example, Mapping):
+        raise TypeError(
+            f'Expected example to be a Mapping, but found {type(example)}')
+    if any(allowed_message_key in example
+           for allowed_message_key in _ALLOWED_MESSAGES_KEYS):
         return 'chat'
     elif any([
             pr in example
@@ -102,6 +112,49 @@ def _is_empty_or_nonexistent(dirpath: str) -> bool:
     return not os.path.isdir(dirpath) or len(os.listdir(dirpath)) == 0
 
 
+def _get_key(dictionary: Mapping[str, Any], allowed_keys: Set[str]):
+    if not isinstance(dictionary, Mapping):
+        raise TypeError(
+            f'Expected dictionary to be a mapping, but found {type(dictionary)}'
+        )
+    desired_keys = allowed_keys.intersection(dictionary.keys())
+    if len(desired_keys) != 1:
+        raise ValueError(
+            f'Dictionary has multiple keys in `allowed_keys`: {desired_keys}')
+    return list(desired_keys)[0]
+
+
+def _validate_chat_formatted_example(example: ChatFormattedDict):
+    if not isinstance(example, Mapping):
+        raise TypeError(
+            f'Expected example to be a mapping, but found {type(example)}')
+    messages = example[_get_key(example, _ALLOWED_MESSAGES_KEYS)]
+    if not isinstance(messages, List):
+        raise TypeError(
+            f'Expected messages to be an iterable, but found {type(messages)}')
+    if len(messages) <= 1:
+        raise ValueError('Chat example must have at least two messages')
+
+    last_message = messages[-1]
+    role_key = _get_key(last_message, _ALLOWED_ROLE_KEYS)
+    last_role = last_message[role_key]
+    if last_role not in _ALLOWED_LAST_MESSAGE_ROLES:
+        raise ValueError(f'Invalid last message role: {last_role}')
+
+    for message in messages:
+        role_key, content_key = _get_key(message, _ALLOWED_ROLE_KEYS), _get_key(
+            message, _ALLOWED_CONTENT_KEYS)
+        if len(message.keys()) != 2:
+            raise ValueError(
+                f'Expected 2 keys in message, but found {len(message.keys())}')
+        if message[role_key] not in _ALLOWED_ROLES:
+            raise ValueError(f'Invalid role: {message[role_key]}')
+        if not isinstance(message[content_key], str):
+            raise TypeError(
+                f'Expected content to be a string, but found {type(message[content_key])}'
+            )
+
+
 def _slice_chat_formatted_example(
         example: ChatFormattedDict,
         tokenizer: PreTrainedTokenizerBase) -> Tuple[str, str]:
@@ -118,18 +171,13 @@ def _slice_chat_formatted_example(
         ValueError: If the chat example has less than two messages or if the last message is not from the assistant.
         KeyError: If a message does not have a role or content.
     """
-    messages = example['messages']
+    _validate_chat_formatted_example(example)
+    messages = example[_get_key(example, _ALLOWED_MESSAGES_KEYS)]
 
-    if len(messages) < 2:
-        raise ValueError(
-            f'chat example must have at least two messages. {messages=}')
-    last_message = messages[-1]
-    if last_message['role'] != 'assistant':
-        raise ValueError(
-            f'last message must be from assistant. {last_message=}')
-    for message in messages:
-        if 'role' not in message or 'content' not in message:
-            raise KeyError(f'message must have role and content. {message=}')
 
     full_conversation = tokenizer.apply_chat_template(messages, tokenize=False)
     prompt = tokenizer.apply_chat_template(messages[:-1],
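For orientation, here is a minimal usage sketch of the new validation helpers, assuming `_get_example_type` and `_validate_chat_formatted_example` are imported from `llmfoundry.data.finetuning.tasks`; the example dicts below are illustrative and are not taken from this PR:

from llmfoundry.data.finetuning.tasks import (_get_example_type,
                                              _validate_chat_formatted_example)

# A well-formed chat example: a mapping with a 'messages' list, where every
# message has exactly a 'role' and a string 'content', and the last message
# comes from the assistant.
good_example = {
    'messages': [
        {'role': 'user', 'content': 'Hello!'},
        {'role': 'assistant', 'content': 'Hi, how can I help?'},
    ]
}
assert _get_example_type(good_example) == 'chat'
_validate_chat_formatted_example(good_example)  # returns without raising

# The strengthened checks reject structural problems up front: here 'messages'
# is a string rather than a list, so validation raises TypeError.
bad_example = {'messages': 'this is not a list of messages'}
try:
    _validate_chat_formatted_example(bad_example)
except TypeError as err:
    print(err)  # Expected messages to be an iterable, but found <class 'str'>
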
4 changes: 3 additions & 1 deletion tests/data/test_template_tokenization.py
@@ -42,8 +42,10 @@ def test_tokenize_chat_example_malformed():
             'content': 'user message not followed by an assistant label'
         }]
     }
+    wrong_type = {'messages': 'this is not a list of messages'}
     malformed_chat_examples = [
-        too_few_messages, no_content, ends_with_user_role, no_assistant_message
+        too_few_messages, no_content, ends_with_user_role, no_assistant_message,
+        wrong_type
     ]
     my_tokenizer = build_tokenizer('mosaicml/mpt-7b-8k-chat', {})
     for example in malformed_chat_examples:
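The new `wrong_type` case is covered by the sketch above. For the other malformed cases named in this test, a rough sketch of how they are expected to surface under `_validate_chat_formatted_example`; the message contents here are made up for illustration and are not the test file's literal fixtures:

from llmfoundry.data.finetuning.tasks import _validate_chat_formatted_example

# Conversation whose final message is from the user: the last-role check
# rejects it, since 'assistant' is the only allowed final role.
ends_with_user_role = {
    'messages': [
        {'role': 'user', 'content': 'Hello GPT!'},
        {'role': 'assistant', 'content': 'Hi, how can I help you?'},
        {'role': 'user', 'content': 'user message not followed by an assistant label'},
    ]
}
try:
    _validate_chat_formatted_example(ends_with_user_role)
except ValueError as err:
    print(err)  # Invalid last message role: user

# Message missing its 'content' key: _get_key requires exactly one allowed
# content key per message, so validation raises ValueError.
no_content = {
    'messages': [
        {'role': 'user'},
        {'role': 'assistant', 'content': 'Hi!'},
    ]
}
try:
    _validate_chat_formatted_example(no_content)
except ValueError as err:
    print(err)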