Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

strengthened chat formatting validation #960

Merged
Merged
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 47 additions & 1 deletion llmfoundry/data/finetuning/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,11 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]:

_ALLOWED_RESPONSE_KEYS = {'response', 'completion'}
_ALLOWED_PROMPT_KEYS = {'prompt'}
_ALLOWED_MESSAGES_KEYS = {'messages'}
_ALLOWED_ROLE_KEYS = {'role'}
_ALLOWED_CONTENT_KEYS = {'content'}
_ALLOWED_ROLES = {'user', 'assistant', 'system'}
_ALLOWED_LAST_MESSAGE_ROLES = {'assistant'}
DOWNLOADED_FT_DATASETS_DIRPATH = os.path.abspath(
os.path.join(os.path.realpath(__file__), os.pardir, os.pardir, os.pardir,
'.downloaded_finetuning'))
Expand Down Expand Up @@ -102,6 +107,46 @@ def _is_empty_or_nonexistent(dirpath: str) -> bool:
return not os.path.isdir(dirpath) or len(os.listdir(dirpath)) == 0


def _get_role_key(message: Dict[str, str]) -> str:
role_keys = _ALLOWED_ROLE_KEYS.intersection(message.keys())
assert len(
milocress marked this conversation as resolved.
Show resolved Hide resolved
role_keys) == 1, f'Expected 1 role key, but found {len(role_keys)}'
role_key = list(role_keys)[0]
return role_key


def _get_content_key(message: Dict[str, str]) -> str:
milocress marked this conversation as resolved.
Show resolved Hide resolved
content_keys = _ALLOWED_CONTENT_KEYS.intersection(message.keys())
assert len(content_keys
) == 1, f'Expected 1 content key, but found {len(content_keys)}'
content_key = list(content_keys)[0]
return content_key


def _get_message_key(example: ChatFormattedDict):
milocress marked this conversation as resolved.
Show resolved Hide resolved
assert len(example.keys()
) == 1, f'Expected 1 message key, but found {len(example.keys())}'
message_key = list(example.keys())[0]
assert message_key in _ALLOWED_MESSAGES_KEYS, f'Invalid message key: {message_key}'
return message_key


def _validate_chat_formatted_example(example: ChatFormattedDict):
messages = example[_get_message_key(example)]
for message in messages:
assert len(message.keys(
)) == 2, f'Expected 2 keys in message, but found {len(message.keys())}'
role_key, _ = _get_role_key(message), _get_content_key(message)
assert message[
role_key] in _ALLOWED_ROLES, f'Invalid role: {message[role_key]}'

assert len(messages) > 1, 'Chat example must have at least two messages'
last_message = messages[-1]
role_key = _get_role_key(last_message)
last_role = last_message[role_key]
assert last_role in _ALLOWED_LAST_MESSAGE_ROLES, f'Invalid last message role: {last_role}'


def _slice_chat_formatted_example(
example: ChatFormattedDict,
tokenizer: PreTrainedTokenizerBase) -> Tuple[str, str]:
Expand All @@ -118,7 +163,8 @@ def _slice_chat_formatted_example(
ValueError: If the chat example has less than two messages or if the last message is not from the assistant.
KeyError: If a message does not have a role or content.
"""
messages = example['messages']
_validate_chat_formatted_example(example)
milocress marked this conversation as resolved.
Show resolved Hide resolved
messages = example[_get_message_key(example)]

if len(messages) < 2:
raise ValueError(
Expand Down
Loading