Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds prompt templates #53

Merged
merged 2 commits into from
Apr 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions examples/magics.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@
],
"source": [
"%%ai chatgpt -f math\n",
"Generate the 2D heat equation in LaTeX surrounded by `$$`. Do not include an explanation."
"Generate the 2D heat equation."
]
},
{
Expand Down Expand Up @@ -487,8 +487,7 @@
],
"source": [
"%%ai j2-jumbo-instruct --format math\n",
"Write the 2d Laplace equation in polar coordinates in pure LaTeX, delimited by `$$`.\n",
"Do not include an explanation."
"Write the 2d Laplace equation in polar coordinates."
]
},
{
Expand Down
14 changes: 14 additions & 0 deletions packages/jupyter-ai-magics/jupyter_ai_magics/magics.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,17 @@
"raw": None
}

MARKDOWN_PROMPT_TEMPLATE = '{prompt}\n\nProduce output in markdown format only.'

PROMPT_TEMPLATES_BY_FORMAT = {
"html": '{prompt}\n\nProduce output in HTML format only, with no markup before or afterward.',
"markdown": MARKDOWN_PROMPT_TEMPLATE,
"md": MARKDOWN_PROMPT_TEMPLATE,
"math": '{prompt}\n\nProduce output in LaTeX format only, with $$ at the beginning and end.',
"json": '{prompt}\n\nProduce output in JSON format only, with nothing before or after it.',
"raw": '{prompt}' # No customization
}

class FormatDict(dict):
"""Subclass of dict to be passed to str#format(). Suppresses KeyError and
leaves replacement field unchanged if replacement field is not associated
Expand Down Expand Up @@ -128,6 +139,9 @@ def ai(self, line, cell=None):
else:
prompt = cell

# Apply a prompt template.
prompt = PROMPT_TEMPLATES_BY_FORMAT[args.format].format(prompt = prompt)

# determine provider and local model IDs
provider_id, local_model_id = self._decompose_model_id(args.model_id)
Provider = self._get_provider(provider_id)
Expand Down