FEAT: added a ChatMessageNormalizer that formats messages in the template specified by a Hugging Face tokenizer (#128)

* added a ChatMessageNormalizer that formats messages in the template specified by a Hugging Face tokenizer
* added ChatMessageNormalizerTokenizerTemplate to chat message normalizer init file
* added tests for ChatMessageNormalizerTokenizerTemplate with three different Hugging Face tokenizers
* reran pre-commit hooks on normalizer tests
* split long strings across multiple lines to satisfy flake8 pre-commit hook
* added example usage of ChatMessageNormalizerTokenizerTemplate to chat_message notebook
* updated chat_message doc python file and ran jupytext
* added docstrings to ChatMessageNormalizerTokenizerTemplate class
* ran pre-commit hooks on chat_message file
* resolved mypy pre-commit hook and reran jupytext
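Conceptually, the normalizer this commit adds delegates message formatting to the chat template bundled with a Hugging Face tokenizer. A minimal sketch of the idea is below; the class and method names mirror the commit description, but the body is illustrative and PyRIT's actual implementation may differ. It assumes only a tokenizer-like object exposing `apply_chat_template` (the real `transformers` tokenizer method):

```python
from dataclasses import dataclass


@dataclass
class ChatMessage:
    """Minimal stand-in for PyRIT's ChatMessage model."""
    role: str
    content: str


class ChatMessageNormalizerTokenizerTemplate:
    """Formats a list of chat messages using the chat template
    stored in a Hugging Face tokenizer (illustrative sketch)."""

    def __init__(self, tokenizer):
        # Any object with an apply_chat_template method works here,
        # e.g. transformers.AutoTokenizer.from_pretrained(...)
        self.tokenizer = tokenizer

    def normalize(self, messages: list[ChatMessage]) -> str:
        # apply_chat_template expects a list of {"role": ..., "content": ...} dicts
        conversation = [{"role": m.role, "content": m.content} for m in messages]
        return self.tokenizer.apply_chat_template(
            conversation, tokenize=False, add_generation_prompt=True
        )
```

With a real tokenizer, `normalize` would return the conversation rendered in that model's training template (e.g. Mistral's `[INST] ... [/INST]` format).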
1 parent ec56d1d · commit e252f4a
Showing 5 changed files with 341 additions and 148 deletions.
{
"cells": [
{
"cell_type": "markdown",
"id": "6b35a22f",
"metadata": {
"lines_to_next_cell": 0
},
"source": [
"### Introduction\n",
"\n",
"This notebook gives an introduction to the concepts of `ChatMessage` and `ChatMessageNormalizer` and how they can be helpful as you start to work with different models.\n",
"\n",
"\n",
"The main format PyRIT works with is the `ChatMessage` paradigm. Any time a user wants to store or retrieve a chat message, they will use the `ChatMessage` object.\n",
"\n",
"However, different models may require different formats. For example, certain models may use ChatML, or may not support system messages. This is handled\n",
"by `ChatMessageNormalizer` and its subclasses.\n",
"\n",
"Below is an example that converts a list of chat messages to ChatML format and back."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "cfdb5cbf",
"metadata": {
"execution": {
"iopub.execute_input": "2024-03-29T21:27:55.175471Z",
"iopub.status.busy": "2024-03-29T21:27:55.175471Z",
"iopub.status.idle": "2024-03-29T21:27:59.086189Z",
"shell.execute_reply": "2024-03-29T21:27:59.085182Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<|im_start|>system\n",
"You are a helpful AI assistant<|im_end|>\n",
"<|im_start|>user\n",
"Hello, how are you?<|im_end|>\n",
"<|im_start|>assistant\n",
"I'm doing well, thanks for asking.<|im_end|>\n",
"\n"
]
}
],
"source": [
"# Copyright (c) Microsoft Corporation.\n",
"# Licensed under the MIT license.\n",
"\n",
"from pyrit.models import ChatMessage\n",
"from pyrit.chat_message_normalizer import ChatMessageNormalizerChatML\n",
"\n",
"messages = [\n",
"    ChatMessage(role=\"system\", content=\"You are a helpful AI assistant\"),\n",
"    ChatMessage(role=\"user\", content=\"Hello, how are you?\"),\n",
"    ChatMessage(role=\"assistant\", content=\"I'm doing well, thanks for asking.\"),\n",
"]\n",
"\n",
"normalizer = ChatMessageNormalizerChatML()\n",
"chatml_messages = normalizer.normalize(messages)\n",
"# chatml_messages is a string in ChatML format\n",
"\n",
"print(chatml_messages)"
]
},
{
"cell_type": "markdown",
"id": "773b9e41",
"metadata": {},
"source": [
"\n",
"If you wish to load a ChatML-format conversation, you can use the `from_chatml` method of `ChatMessageNormalizerChatML`. This will return a list of `ChatMessage` objects that you can then use."
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "7bbef6a1",
"metadata": {
"execution": {
"iopub.execute_input": "2024-03-29T21:27:59.089740Z",
"iopub.status.busy": "2024-03-29T21:27:59.089740Z",
"iopub.status.idle": "2024-03-29T21:27:59.101747Z",
"shell.execute_reply": "2024-03-29T21:27:59.100746Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[ChatMessage(role='system', content='You are a helpful AI assistant', name=None, tool_calls=None, tool_call_id=None), ChatMessage(role='user', content='Hello, how are you?', name=None, tool_calls=None, tool_call_id=None), ChatMessage(role='assistant', content=\"I'm doing well, thanks for asking.\", name=None, tool_calls=None, tool_call_id=None)]\n"
]
}
],
"source": [
"\n",
"chat_messages = normalizer.from_chatml(\n",
"    \"\"\"\\\n",
"    <|im_start|>system\n",
"    You are a helpful AI assistant<|im_end|>\n",
"    <|im_start|>user\n",
"    Hello, how are you?<|im_end|>\n",
"    <|im_start|>assistant\n",
"    I'm doing well, thanks for asking.<|im_end|>\"\"\"\n",
")\n",
"\n",
"print(chat_messages)"
]
},
{
"cell_type": "markdown",
"id": "6f803626",
"metadata": {},
"source": [
"To see how to use this in action, check out the [AML endpoint](./aml_endpoints.ipynb) notebook. It takes a `chat_message_normalizer` parameter so that an AML model can support various chat message formats."
]
},
{
"cell_type": "markdown",
"id": "c97cf4c4",
"metadata": {},
"source": [
"Besides ChatML, there are many other chat templates that a model might be trained on. If you would like to apply the template stored in a Hugging Face tokenizer,\n",
"you can utilize `ChatMessageNormalizerTokenizerTemplate`. In the example below, we load the tokenizer for Mistral-7B-Instruct-v0.1 and apply its chat template to\n",
"the messages. Note that this template only adds `[INST]` and `[/INST]` tokens to the user messages for instruction fine-tuning."
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "1e674217",
"metadata": {
"execution": {
"iopub.execute_input": "2024-03-29T21:27:59.104752Z",
"iopub.status.busy": "2024-03-29T21:27:59.104752Z",
"iopub.status.idle": "2024-03-29T21:27:59.795124Z",
"shell.execute_reply": "2024-03-29T21:27:59.794119Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<s>[INST] Hello, how are you? [/INST]I'm doing well, thanks for asking.</s> [INST] What is your favorite food? [/INST]\n"
]
}
],
"source": [
"from pyrit.chat_message_normalizer import ChatMessageNormalizerTokenizerTemplate\n",
"from transformers import AutoTokenizer\n",
"\n",
"messages = [\n",
"    ChatMessage(role=\"user\", content=\"Hello, how are you?\"),\n",
"    ChatMessage(role=\"assistant\", content=\"I'm doing well, thanks for asking.\"),\n",
"    ChatMessage(role=\"user\", content=\"What is your favorite food?\"),\n",
"]\n",
"\n",
"# load the tokenizer\n",
"tokenizer = AutoTokenizer.from_pretrained(\"mistralai/Mistral-7B-Instruct-v0.1\")\n",
"\n",
"# create the normalizer and pass in the tokenizer\n",
"tokenizer_normalizer = ChatMessageNormalizerTokenizerTemplate(tokenizer)\n",
"\n",
"tokenizer_template_messages = tokenizer_normalizer.normalize(messages)\n",
"print(tokenizer_template_messages)"
]
}
],
"metadata": {
"jupytext": {
"cell_metadata_filter": "-all"
},
"kernelspec": {
"display_name": "pyrit_kernel",
"language": "python",
"name": "pyrit_kernel"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
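The ChatML round trip the notebook demonstrates can be sketched in a few lines. This is an illustrative standalone version, not PyRIT's actual `ChatMessageNormalizerChatML` (which also handles indented input and other details); the `ChatMessage` dataclass here is a minimal stand-in for PyRIT's model:

```python
import re
from dataclasses import dataclass


@dataclass
class ChatMessage:
    """Minimal stand-in for PyRIT's ChatMessage model."""
    role: str
    content: str


def to_chatml(messages: list[ChatMessage]) -> str:
    # Each message becomes "<|im_start|>role\ncontent<|im_end|>\n"
    return "".join(
        f"<|im_start|>{m.role}\n{m.content}<|im_end|>\n" for m in messages
    )


def from_chatml(text: str) -> list[ChatMessage]:
    # Parse role/content pairs back out of a ChatML-formatted string
    pattern = re.compile(r"<\|im_start\|>(\w+)\n(.*?)<\|im_end\|>", re.DOTALL)
    return [
        ChatMessage(role=role, content=content.strip())
        for role, content in pattern.findall(text)
    ]
```

Serializing a conversation with `to_chatml` and feeding the result to `from_chatml` recovers the original message list, which is the same round trip the notebook performs with `normalize` and `from_chatml`.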