Skip to content

Commit

Permalink
Backport PR jupyterlab#560: Fixes lookup for custom chains
Browse files Browse the repository at this point in the history
  • Loading branch information
JasonWeill authored and meeseeksmachine committed Jan 3, 2024
1 parent 5eb207b commit 7ea7cee
Showing 1 changed file with 16 additions and 1 deletion.
17 changes: 16 additions & 1 deletion packages/jupyter-ai-magics/jupyter_ai_magics/magics.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,7 +434,11 @@ def _append_exchange_openai(self, prompt: str, output: str):

def _decompose_model_id(self, model_id: str):
"""Breaks down a model ID into a two-tuple (provider_id, local_model_id). Returns (None, None) if indeterminate."""
if model_id in self.custom_model_registry:
# custom_model_registry maps keys to either a model name (a string) or an LLMChain.
# If this is an alias to another model, expand the full name of the model.
if model_id in self.custom_model_registry and isinstance(
self.custom_model_registry[model_id], str
):
model_id = self.custom_model_registry[model_id]

return decompose_model_id(model_id, self.providers)
Expand Down Expand Up @@ -508,6 +512,17 @@ def run_ai_cell(self, args: CellArgs, prompt: str):
)

provider_id, local_model_id = self._decompose_model_id(args.model_id)

# If this is a custom chain, send the message to the custom chain.
if args.model_id in self.custom_model_registry and isinstance(
self.custom_model_registry[args.model_id], LLMChain
):
return self.display_output(
self.custom_model_registry[args.model_id].run(prompt),
args.format,
{"jupyter_ai": {"custom_chain_id": args.model_id}},
)

Provider = self._get_provider(provider_id)
if Provider is None:
return TextOrMarkdown(
Expand Down

0 comments on commit 7ea7cee

Please sign in to comment.