Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

strengthened chat formatting validation #960

Merged
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 13 additions & 8 deletions llmfoundry/data/finetuning/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,37 +109,42 @@ def _is_empty_or_nonexistent(dirpath: str) -> bool:

def _get_role_key(message: Dict[str, str]) -> str:
    """Return the single key of ``message`` that identifies the role.

    Args:
        message: One chat message mapping.

    Returns:
        The one key of ``message`` that is also in ``_ALLOWED_ROLE_KEYS``.

    Raises:
        AssertionError: If ``message`` does not contain exactly one key from
            ``_ALLOWED_ROLE_KEYS``.
    """
    role_keys = _ALLOWED_ROLE_KEYS.intersection(message.keys())
    assert len(
        role_keys) == 1, f'Expected 1 role key, but found {len(role_keys)}'
    # Exactly one element is guaranteed by the assert above.
    role_key = list(role_keys)[0]
    return role_key


def _get_content_key(message: Dict[str, str]) -> str:
    """Return the single key of ``message`` that identifies the content.

    Args:
        message: One chat message mapping.

    Returns:
        The one key of ``message`` that is also in ``_ALLOWED_CONTENT_KEYS``.

    Raises:
        AssertionError: If ``message`` does not contain exactly one key from
            ``_ALLOWED_CONTENT_KEYS``.
    """
    content_keys = _ALLOWED_CONTENT_KEYS.intersection(message.keys())
    assert len(content_keys
              ) == 1, f'Expected 1 content key, but found {len(content_keys)}'
    # Exactly one element is guaranteed by the assert above.
    content_key = list(content_keys)[0]
    return content_key


def _get_message_key(example: ChatFormattedDict):
    """Return the only key of ``example`` after checking that it is allowed.

    Args:
        example: A chat-formatted example mapping.

    Returns:
        The sole key of ``example``.

    Raises:
        AssertionError: If ``example`` does not have exactly one key, or if
            that key is not in ``_ALLOWED_MESSAGES_KEYS``.
    """
    assert len(example.keys()
              ) == 1, f'Expected 1 message key, but found {len(example.keys())}'
    message_key = list(example.keys())[0]
    assert message_key in _ALLOWED_MESSAGES_KEYS, f'Invalid message key: {message_key}'
    return message_key


def _validate_chat_formatted_example(example: ChatFormattedDict):
    """Check that ``example`` is a well-formed chat example.

    The checks, in order:

    1. ``example`` has exactly one key and it is in
       ``_ALLOWED_MESSAGES_KEYS`` (via ``_get_message_key``).
    2. Every message has exactly two keys: one allowed role key and one
       allowed content key.
    3. Every message's role value is in ``_ALLOWED_ROLES``.
    4. There are at least two messages.
    5. The role of the last message is in ``_ALLOWED_LAST_MESSAGE_ROLES``.

    Args:
        example: A chat-formatted example mapping.

    Raises:
        AssertionError: If any of the checks above fails.
    """
    messages = example[_get_message_key(example)]
    for message in messages:
        assert len(message.keys(
        )) == 2, f'Expected 2 keys in message, but found {len(message.keys())}'
        role_key = _get_role_key(message)
        # Called only for its validation: asserts exactly one content key.
        _get_content_key(message)
        assert message[
            role_key] in _ALLOWED_ROLES, f'Invalid role: {message[role_key]}'

    assert len(messages) > 1, 'Chat example must have at least two messages'
    last_message = messages[-1]
    role_key = _get_role_key(last_message)
    last_role = last_message[role_key]
    assert last_role in _ALLOWED_LAST_MESSAGE_ROLES, f'Invalid last message role: {last_role}'


def _slice_chat_formatted_example(
Expand Down
Loading