
Commit

FEAT: added a ChatMessageNormalizer that formats messages in the template specified by a Hugging Face tokenizer (#128)

* added a ChatMessageNormalizer that formats messages in the template specified by a Hugging Face tokenizer

* added ChatMessageNormalizerTokenizerTemplate to chat message normalizer init file

* added tests for ChatMessageNormalizerTokenizerTemplate with three different Hugging Face tokenizers

* reran pre-commit hooks on normalizer tests

* split long strings across multiple lines to satisfy flake8 pre-commit hook

* added example usage of ChatMessageNormalizerTokenizerTemplate to chat_message notebook

* updated chat_message doc python file and ran jupytext

* added docstrings to ChatMessageNormalizerTokenizerTemplate class

* ran pre-commit hooks on chat_message file

* resolved mypy pre-commit hook and reran jupytext
blakebullwinkel authored Mar 29, 2024
1 parent ec56d1d commit e252f4a
Showing 5 changed files with 341 additions and 148 deletions.
347 changes: 199 additions & 148 deletions doc/code/memory/chat_message.ipynb
@@ -1,148 +1,199 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "6b35a22f",
"metadata": {
"lines_to_next_cell": 0
},
"source": [
"### Introduction\n",
"\n",
"This notebook gives an introduction to the concept of `ChatMessage` and `ChatMessageNormalizer` and how they can be helpful as you start to work with different models.\n",
"\n",
"\n",
"The main format PyRIT works with is the `ChatMessage` paradigm. Any time a user wants to store or retrieve a chat message, they will use the `ChatMessage` object.\n",
"\n",
"However, different models may require different formats. For example, certain models may use chatml, or may not support system messages. This is handled\n",
"by `ChatMessageNormalizer` and its subclasses.\n",
"\n",
"Below is an example that converts a list of chat messages to chatml format and back."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "cfdb5cbf",
"metadata": {
"execution": {
"iopub.execute_input": "2024-03-29T21:27:55.175471Z",
"iopub.status.busy": "2024-03-29T21:27:55.175471Z",
"iopub.status.idle": "2024-03-29T21:27:59.086189Z",
"shell.execute_reply": "2024-03-29T21:27:59.085182Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<|im_start|>system\n",
"You are a helpful AI assistant<|im_end|>\n",
"<|im_start|>user\n",
"Hello, how are you?<|im_end|>\n",
"<|im_start|>assistant\n",
"I'm doing well, thanks for asking.<|im_end|>\n",
"\n"
]
}
],
"source": [
"# Copyright (c) Microsoft Corporation.\n",
"# Licensed under the MIT license.\n",
"\n",
"from pyrit.models import ChatMessage\n",
"from pyrit.chat_message_normalizer import ChatMessageNormalizerChatML\n",
"\n",
"messages = [\n",
" ChatMessage(role=\"system\", content=\"You are a helpful AI assistant\"),\n",
" ChatMessage(role=\"user\", content=\"Hello, how are you?\"),\n",
" ChatMessage(role=\"assistant\", content=\"I'm doing well, thanks for asking.\"),\n",
"]\n",
"\n",
"normalizer = ChatMessageNormalizerChatML()\n",
"chatml_messages = normalizer.normalize(messages)\n",
"# chatml_messages is a string in chatml format\n",
"\n",
"print(chatml_messages)"
]
},
{
"cell_type": "markdown",
"id": "773b9e41",
"metadata": {},
"source": [
"\n",
"If you wish to load a chatml-format conversation, you can use the `from_chatml` method of `ChatMessageNormalizerChatML`. This will return a list of `ChatMessage` objects that you can then use."
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "7bbef6a1",
"metadata": {
"execution": {
"iopub.execute_input": "2024-03-29T21:27:59.089740Z",
"iopub.status.busy": "2024-03-29T21:27:59.089740Z",
"iopub.status.idle": "2024-03-29T21:27:59.101747Z",
"shell.execute_reply": "2024-03-29T21:27:59.100746Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[ChatMessage(role='system', content='You are a helpful AI assistant', name=None, tool_calls=None, tool_call_id=None), ChatMessage(role='user', content='Hello, how are you?', name=None, tool_calls=None, tool_call_id=None), ChatMessage(role='assistant', content=\"I'm doing well, thanks for asking.\", name=None, tool_calls=None, tool_call_id=None)]\n"
]
}
],
"source": [
"\n",
"chat_messages = normalizer.from_chatml(\n",
" \"\"\"\\\n",
" <|im_start|>system\n",
" You are a helpful AI assistant<|im_end|>\n",
" <|im_start|>user\n",
" Hello, how are you?<|im_end|>\n",
" <|im_start|>assistant\n",
" I'm doing well, thanks for asking.<|im_end|>\"\"\"\n",
")\n",
"\n",
"print(chat_messages)"
]
},
{
"cell_type": "markdown",
"id": "6f803626",
"metadata": {},
"source": [
"To see how to use this in action, check out the [aml endpoint](./aml_endpoints.ipynb) notebook. It takes a `chat_message_normalizer` parameter so that an AML model can support various chat message formats."
]
},
{
"cell_type": "markdown",
"id": "c97cf4c4",
"metadata": {},
"source": [
"Besides chatml, there are many other chat templates that a model might be trained on. If you would like to apply the template stored in a Hugging Face tokenizer,\n",
"you can utilize `ChatMessageNormalizerTokenizerTemplate`. In the example below, we load the tokenizer for Mistral-7B-Instruct-v0.1 and apply its chat template to\n",
"the messages. Note that this template only adds `[INST]` and `[/INST]` tokens to the user messages for instruction fine-tuning."
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "1e674217",
"metadata": {
"execution": {
"iopub.execute_input": "2024-03-29T21:27:59.104752Z",
"iopub.status.busy": "2024-03-29T21:27:59.104752Z",
"iopub.status.idle": "2024-03-29T21:27:59.795124Z",
"shell.execute_reply": "2024-03-29T21:27:59.794119Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<s>[INST] Hello, how are you? [/INST]I'm doing well, thanks for asking.</s> [INST] What is your favorite food? [/INST]\n"
]
}
],
"source": [
"from pyrit.chat_message_normalizer import ChatMessageNormalizerTokenizerTemplate\n",
"from transformers import AutoTokenizer\n",
"\n",
"messages = [\n",
" ChatMessage(role=\"user\", content=\"Hello, how are you?\"),\n",
" ChatMessage(role=\"assistant\", content=\"I'm doing well, thanks for asking.\"),\n",
" ChatMessage(role=\"user\", content=\"What is your favorite food?\"),\n",
"]\n",
"\n",
"# load the tokenizer\n",
"tokenizer = AutoTokenizer.from_pretrained(\"mistralai/Mistral-7B-Instruct-v0.1\")\n",
"\n",
"# create the normalizer and pass in the tokenizer\n",
"tokenizer_normalizer = ChatMessageNormalizerTokenizerTemplate(tokenizer)\n",
"\n",
"tokenizer_template_messages = tokenizer_normalizer.normalize(messages)\n",
"print(tokenizer_template_messages)"
]
}
],
"metadata": {
"jupytext": {
"cell_metadata_filter": "-all"
},
"kernelspec": {
"display_name": "pyrit_kernel",
"language": "python",
"name": "pyrit_kernel"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
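The ChatML round trip demonstrated in the notebook above can be sketched in plain Python, independent of PyRIT. This is a minimal illustration of the format itself, with a hypothetical `SimpleChatMessage` standing in for `pyrit.models.ChatMessage`; `ChatMessageNormalizerChatML`'s actual implementation may differ.

```python
import re
from dataclasses import dataclass


@dataclass
class SimpleChatMessage:
    """Stand-in for pyrit.models.ChatMessage (illustration only)."""

    role: str
    content: str


def to_chatml(messages: list[SimpleChatMessage]) -> str:
    # Each message becomes an <|im_start|>role ... <|im_end|> block.
    return "".join(
        f"<|im_start|>{m.role}\n{m.content}<|im_end|>\n" for m in messages
    )


def from_chatml(text: str) -> list[SimpleChatMessage]:
    # Recover role and content from each <|im_start|>...<|im_end|> block.
    pattern = re.compile(r"<\|im_start\|>(\w+)\n(.*?)<\|im_end\|>", re.DOTALL)
    return [SimpleChatMessage(role=r, content=c) for r, c in pattern.findall(text)]


messages = [
    SimpleChatMessage("system", "You are a helpful AI assistant"),
    SimpleChatMessage("user", "Hello, how are you?"),
]
chatml = to_chatml(messages)
assert from_chatml(chatml) == messages
```

The key property, visible in the notebook output as well, is that the conversion is lossless: serializing to ChatML and parsing back yields the original message list.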
24 changes: 24 additions & 0 deletions doc/code/memory/chat_message.py
@@ -49,3 +49,27 @@

# %% [markdown]
# To see how to use this in action, check out the [aml endpoint](./aml_endpoints.ipynb) notebook. It takes a `chat_message_normalizer` parameter so that an AML model can support various chat message formats.

# %% [markdown]
# Besides chatml, there are many other chat templates that a model might be trained on. If you would like to apply the template stored in a Hugging Face tokenizer,
# you can utilize `ChatMessageNormalizerTokenizerTemplate`. In the example below, we load the tokenizer for Mistral-7B-Instruct-v0.1 and apply its chat template to
# the messages. Note that this template only adds `[INST]` and `[/INST]` tokens to the user messages for instruction fine-tuning.

# %%
from pyrit.chat_message_normalizer import ChatMessageNormalizerTokenizerTemplate
from transformers import AutoTokenizer

messages = [
ChatMessage(role="user", content="Hello, how are you?"),
ChatMessage(role="assistant", content="I'm doing well, thanks for asking."),
ChatMessage(role="user", content="What is your favorite food?"),
]

# load the tokenizer
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1")

# create the normalizer and pass in the tokenizer
tokenizer_normalizer = ChatMessageNormalizerTokenizerTemplate(tokenizer)

tokenizer_template_messages = tokenizer_normalizer.normalize(messages)
print(tokenizer_template_messages)
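The diff above shows how the new normalizer is used but not how it is implemented. Presumably it delegates to the tokenizer's `apply_chat_template` method, which Hugging Face tokenizers expose for rendering conversations through the model's chat template. Below is a hedged sketch of that idea; `TokenizerTemplateNormalizer`, `StubTokenizer`, and `Msg` are all hypothetical names, and the stub mimics the Mistral-style `[INST]` format so the example runs without downloading a model.

```python
from collections import namedtuple

# Hypothetical lightweight message type for the sketch.
Msg = namedtuple("Msg", ["role", "content"])


class StubTokenizer:
    """Minimal stand-in for a Hugging Face tokenizer (illustration only).

    Real tokenizers render messages through a Jinja2 chat template; this
    stub just mimics the Mistral-style [INST] instruction format.
    """

    def apply_chat_template(self, messages, tokenize=False):
        parts = []
        for m in messages:
            if m["role"] == "user":
                parts.append(f"[INST] {m['content']} [/INST]")
            else:
                parts.append(m["content"])
        return " ".join(parts)


class TokenizerTemplateNormalizer:
    """Hypothetical sketch of what ChatMessageNormalizerTokenizerTemplate
    might do; the actual PyRIT implementation is not shown in this diff."""

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def normalize(self, messages):
        # Hugging Face chat templates expect plain dicts with role/content keys.
        dicts = [{"role": m.role, "content": m.content} for m in messages]
        return self.tokenizer.apply_chat_template(dicts, tokenize=False)


normalizer = TokenizerTemplateNormalizer(StubTokenizer())
print(normalizer.normalize([Msg("user", "Hello, how are you?")]))
# [INST] Hello, how are you? [/INST]
```

Swapping the stub for a real `AutoTokenizer` (as the notebook does with Mistral-7B-Instruct-v0.1) would produce the model's own template output instead.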
2 changes: 2 additions & 0 deletions pyrit/chat_message_normalizer/__init__.py
@@ -6,10 +6,12 @@
from pyrit.chat_message_normalizer.chat_message_nop import ChatMessageNop
from pyrit.chat_message_normalizer.generic_system_squash import GenericSystemSquash
from pyrit.chat_message_normalizer.chat_message_normalizer_chatml import ChatMessageNormalizerChatML
from pyrit.chat_message_normalizer.chat_message_normalizer_tokenizer import ChatMessageNormalizerTokenizerTemplate

__all__ = [
"ChatMessageNormalizer",
"ChatMessageNop",
"GenericSystemSquash",
"ChatMessageNormalizerChatML",
"ChatMessageNormalizerTokenizerTemplate",
]
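The `__init__.py` diff also re-exports `GenericSystemSquash`, which addresses the "may not support system messages" case mentioned in the notebook. Its implementation is not part of this diff; a plausible approach is to fold a leading system message into the first user message, sketched below under that assumption (the real class may behave differently).

```python
def squash_system_message(messages: list[dict]) -> list[dict]:
    """Fold a leading system message into the first user message.

    Sketch of the 'system squash' idea for models without system-role
    support; PyRIT's actual GenericSystemSquash may differ.
    """
    if not messages or messages[0]["role"] != "system":
        return list(messages)  # nothing to squash
    system, rest = messages[0], messages[1:]
    if rest and rest[0]["role"] == "user":
        # Prepend the system prompt to the first user turn.
        merged = dict(rest[0])
        merged["content"] = f"{system['content']}\n\n{merged['content']}"
        return [merged] + rest[1:]
    # No user turn to merge into: re-emit the system prompt as a user turn.
    return [{"role": "user", "content": system["content"]}] + rest


conversation = [
    {"role": "system", "content": "Be brief."},
    {"role": "user", "content": "Hi"},
]
print(squash_system_message(conversation))
# [{'role': 'user', 'content': 'Be brief.\n\nHi'}]
```

This keeps the same `normalize`-style contract as the other normalizers: a list of messages in, a model-compatible list of messages out.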