Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds multi-environment variable authentication, Baidu Qianfan ERNIE-bot provider #531

Merged
merged 11 commits into from
Dec 21, 2023
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -132,3 +132,5 @@ dev.sh
.jupyter_ystore.db

.yarn

.conda/
4 changes: 3 additions & 1 deletion docs/source/users/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -122,21 +122,23 @@ Jupyter AI supports a wide range of model providers and models. To use Jupyter A

Jupyter AI supports the following model providers:

| Provider | Provider ID | Environment variable | Python package(s) |
| Provider | Provider ID | Environment variable(s) | Python package(s) |
|---------------------|----------------------|----------------------------|---------------------------------|
| AI21 | `ai21` | `AI21_API_KEY` | `ai21` |
| Anthropic | `anthropic` | `ANTHROPIC_API_KEY` | `anthropic` |
| Anthropic (chat) | `anthropic-chat` | `ANTHROPIC_API_KEY` | `anthropic` |
| Bedrock | `bedrock` | N/A | `boto3` |
| Bedrock (chat) | `bedrock-chat` | N/A | `boto3` |
| Cohere | `cohere` | `COHERE_API_KEY` | `cohere` |
| ERNIE-Bot | `qianfan` | `QIANFAN_AK`, `QIANFAN_SK` | `qianfan` |
| GPT4All | `gpt4all` | N/A | `gpt4all` |
| Hugging Face Hub | `huggingface_hub` | `HUGGINGFACEHUB_API_TOKEN` | `huggingface_hub`, `ipywidgets`, `pillow` |
| OpenAI | `openai` | `OPENAI_API_KEY` | `openai` |
| OpenAI (chat) | `openai-chat` | `OPENAI_API_KEY` | `openai` |
| SageMaker | `sagemaker-endpoint` | N/A | `boto3` |

The environment variable names shown above are also the names of the settings keys used when setting up the chat interface.
If multiple variables are listed for a provider, **all** must be specified.

To use the Bedrock models, you need access to the Bedrock service. For more information, see the
[Amazon Bedrock Homepage](https://aws.amazon.com/bedrock/).
Expand Down
2 changes: 2 additions & 0 deletions packages/jupyter-ai-magics/jupyter_ai_magics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
GPT4AllEmbeddingsProvider,
HfHubEmbeddingsProvider,
OpenAIEmbeddingsProvider,
QianfanEmbeddingsEndpointProvider,
)
from .exception import store_exception
from .magics import AiMagics
Expand All @@ -27,6 +28,7 @@
GPT4AllProvider,
HfHubProvider,
OpenAIProvider,
QianfanProvider,
SmEndpointProvider,
)

Expand Down
3 changes: 3 additions & 0 deletions packages/jupyter-ai-magics/jupyter_ai_magics/aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,7 @@
"gpt3": "openai:text-davinci-003",
"chatgpt": "openai-chat:gpt-3.5-turbo",
"gpt4": "openai-chat:gpt-4",
"ernie-bot": "qianfan:ERNIE-Bot",
"ernie-bot-4": "qianfan:ERNIE-Bot-4",
"titan": "bedrock:amazon.titan-tg1-large",
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,15 @@
AwsAuthStrategy,
EnvAuthStrategy,
Field,
MultiEnvAuthStrategy,
)
from langchain.embeddings import (
BedrockEmbeddings,
CohereEmbeddings,
GPT4AllEmbeddings,
HuggingFaceHubEmbeddings,
OpenAIEmbeddings,
QianfanEmbeddingsEndpoint,
)
from langchain.pydantic_v1 import BaseModel, Extra

Expand Down Expand Up @@ -127,3 +129,14 @@ def __init__(self, **kwargs):
models = ["all-MiniLM-L6-v2-f16"]
model_id_key = "model_id"
pypi_package_deps = ["gpt4all"]


class QianfanEmbeddingsEndpointProvider(
BaseEmbeddingsProvider, QianfanEmbeddingsEndpoint
):
id = "qianfan"
name = "ERNIE-Bot"
models = ["ERNIE-Bot", "ERNIE-Bot-4"]
model_id_key = "model"
pypi_package_deps = ["qianfan"]
auth_strategy = MultiEnvAuthStrategy(names=["QIANFAN_AK", "QIANFAN_SK"])
97 changes: 60 additions & 37 deletions packages/jupyter-ai-magics/jupyter_ai_magics/magics.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from IPython import get_ipython
from IPython.core.magic import Magics, line_cell_magic, magics_class
from IPython.display import HTML, JSON, Markdown, Math
from jupyter_ai_magics.aliases import MODEL_ID_ALIASES
from jupyter_ai_magics.utils import decompose_model_id, get_lm_providers
from langchain.chains import LLMChain
from langchain.schema import HumanMessage
Expand All @@ -28,14 +29,6 @@
)
from .providers import BaseProvider

MODEL_ID_ALIASES = {
"gpt2": "huggingface_hub:gpt2",
"gpt3": "openai:text-davinci-003",
"chatgpt": "openai-chat:gpt-3.5-turbo",
"gpt4": "openai-chat:gpt-4",
"titan": "bedrock:amazon.titan-tg1-large",
}


class TextOrMarkdown:
def __init__(self, text, markdown):
Expand Down Expand Up @@ -95,6 +88,18 @@ def _repr_mimebundle_(self, include=None, exclude=None):

AI_COMMANDS = {"delete", "error", "help", "list", "register", "update"}

# Strings for listing providers and models
# Avoid composing strings, to make localization easier in the future
ENV_NOT_SET = "You have not set this environment variable, so you cannot use this provider's models."
ENV_SET = (
"You have set this environment variable, so you can use this provider's models."
)
MULTIENV_NOT_SET = "You have not set all of these environment variables, so you cannot use this provider's models."
MULTIENV_SET = "You have set all of these environment variables, so you can use this provider's models."

ENV_REQUIRES = "Requires environment variable:"
MULTIENV_REQUIRES = "Requires environment variables:"


class FormatDict(dict):
"""Subclass of dict to be passed to str#format(). Suppresses KeyError and
Expand Down Expand Up @@ -177,44 +182,53 @@ def _ai_env_status_for_provider_markdown(self, provider_id):
):
return na_message # No emoji

try:
env_var = self.providers[provider_id].auth_strategy.name
except AttributeError: # No "name" attribute
not_set_title = ENV_NOT_SET
set_title = ENV_SET
env_status_ok = False

auth_strategy = self.providers[provider_id].auth_strategy
if auth_strategy.type == "env":
var_name = auth_strategy.name
env_var_display = f"`{var_name}`"
env_status_ok = var_name in os.environ
elif auth_strategy.type == "multienv":
# Check multiple environment variables
var_names = self.providers[provider_id].auth_strategy.names
formatted_names = [f"`{name}`" for name in var_names]
env_var_display = ", ".join(formatted_names)
env_status_ok = all(var_name in os.environ for var_name in var_names)
not_set_title = MULTIENV_NOT_SET
set_title = MULTIENV_SET
else: # No environment variables
return na_message

output = f"`{env_var}` | "
if os.getenv(env_var) == None:
output += (
'<abbr title="You have not set this environment variable, '
+ "so you cannot use this provider's models.\">❌</abbr>"
)
output = f"{env_var_display} | "
if env_status_ok:
output += f'<abbr title="{set_title}">✅</abbr>'
else:
output += (
'<abbr title="You have set this environment variable, '
+ "so you can use this provider's models.\">✅</abbr>"
)
output += f'<abbr title="{not_set_title}">❌</abbr>'

return output

def _ai_env_status_for_provider_text(self, provider_id):
if (
provider_id not in self.providers
or self.providers[provider_id].auth_strategy == None
# only handle providers with "env" or "multienv" auth strategy
auth_strategy = getattr(self.providers[provider_id], "auth_strategy", None)
if not auth_strategy or (
auth_strategy.type != "env" and auth_strategy.type != "multienv"
):
return "" # No message necessary

try:
env_var = self.providers[provider_id].auth_strategy.name
except AttributeError: # No "name" attribute
return ""

output = f"Requires environment variable {env_var} "
if os.getenv(env_var) != None:
output += "(set)"
else:
output += "(not set)"
prefix = ENV_REQUIRES if auth_strategy.type == "env" else MULTIENV_REQUIRES
envvars = (
[auth_strategy.name]
if auth_strategy.type == "env"
else auth_strategy.names[:]
)

for i in range(len(envvars)):
envvars[i] += " (set)" if envvars[i] in os.environ else " (not set)"

return output + "\n"
return prefix + " " + ", ".join(envvars) + "\n"

# Is this a name of a Python variable that can be called as a LangChain chain?
def _is_langchain_chain(self, name):
Expand Down Expand Up @@ -493,13 +507,22 @@ def run_ai_cell(self, args: CellArgs, prompt: str):
# validate presence of authn credentials
auth_strategy = self.providers[provider_id].auth_strategy
if auth_strategy:
# TODO: handle auth strategies besides EnvAuthStrategy
if auth_strategy.type == "env" and auth_strategy.name not in os.environ:
raise OSError(
f"Authentication environment variable {auth_strategy.name} not provided.\n"
f"Authentication environment variable {auth_strategy.name} is not set.\n"
f"An authentication token is required to use models from the {Provider.name} provider.\n"
f"Please specify it via `%env {auth_strategy.name}=token`. "
) from None
if auth_strategy.type == "multienv":
# Multiple environment variables must be set
missing_vars = [
var for var in auth_strategy.names if var not in os.environ
]
raise OSError(
f"Authentication environment variables {missing_vars} are not set.\n"
f"Multiple authentication tokens are required to use models from the {Provider.name} provider.\n"
f"Please specify them all via `%env` commands. "
) from None

# configure and instantiate provider
provider_params = {"model_id": local_model_id}
Expand Down
14 changes: 13 additions & 1 deletion packages/jupyter-ai-magics/jupyter_ai_magics/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
BedrockChat,
ChatAnthropic,
ChatOpenAI,
QianfanChatEndpoint,
)
from langchain.chat_models.base import BaseChatModel
from langchain.llms import (
Expand All @@ -34,6 +35,7 @@
HuggingFaceHub,
OpenAI,
OpenAIChat,
QianfanLLMEndpoint,
SagemakerEndpoint,
)
from langchain.llms.sagemaker_endpoint import LLMContentHandler
Expand All @@ -54,7 +56,7 @@ class EnvAuthStrategy(BaseModel):
class MultiEnvAuthStrategy(BaseModel):
"""Require multiple auth tokens via multiple environment variables."""

type: Literal["file"] = "file"
type: Literal["multienv"] = "multienv"
names: List[str]


Expand Down Expand Up @@ -775,3 +777,13 @@ async def _agenerate(self, *args, **kwargs) -> Coroutine[Any, Any, LLMResult]:
@property
def allows_concurrency(self):
return not "anthropic" in self.model_id


# Baidu QianfanChat provider. temporarily living as a separate class until
class QianfanProvider(BaseProvider, QianfanChatEndpoint):
id = "qianfan"
name = "ERNIE-Bot"
models = ["ERNIE-Bot", "ERNIE-Bot-4"]
model_id_key = "model_name"
pypi_package_deps = ["qianfan"]
auth_strategy = MultiEnvAuthStrategy(names=["QIANFAN_AK", "QIANFAN_SK"])
5 changes: 4 additions & 1 deletion packages/jupyter-ai-magics/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ all = [
"ipywidgets",
"pillow",
"openai",
"boto3"
"boto3",
"qianfan"
]

[project.entry-points."jupyter_ai.model_providers"]
Expand All @@ -67,13 +68,15 @@ sagemaker-endpoint = "jupyter_ai_magics:SmEndpointProvider"
amazon-bedrock = "jupyter_ai_magics:BedrockProvider"
anthropic-chat = "jupyter_ai_magics:ChatAnthropicProvider"
amazon-bedrock-chat = "jupyter_ai_magics:BedrockChatProvider"
qianfan = "jupyter_ai_magics:QianfanProvider"

[project.entry-points."jupyter_ai.embeddings_model_providers"]
bedrock = "jupyter_ai_magics:BedrockEmbeddingsProvider"
cohere = "jupyter_ai_magics:CohereEmbeddingsProvider"
gpt4all = "jupyter_ai_magics:GPT4AllEmbeddingsProvider"
huggingface_hub = "jupyter_ai_magics:HfHubEmbeddingsProvider"
openai = "jupyter_ai_magics:OpenAIEmbeddingsProvider"
qianfan = "jupyter_ai_magics:QianfanEmbeddingsEndpointProvider"

[tool.hatch.version]
source = "nodejs"
Expand Down
15 changes: 15 additions & 0 deletions packages/jupyter-ai/src/components/chat-settings.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -102,12 +102,27 @@ export function ChatSettings(): JSX.Element {
) {
newApiKeys[lmAuth.name] = '';
}
if (lmAuth?.type === 'multienv') {
lmAuth.names.forEach(apiKey => {
if (!server.config.api_keys.includes(apiKey)) {
newApiKeys[apiKey] = '';
}
});
}

if (
emAuth?.type === 'env' &&
!server.config.api_keys.includes(emAuth.name)
) {
newApiKeys[emAuth.name] = '';
}
if (emAuth?.type === 'multienv') {
emAuth.names.forEach(apiKey => {
if (!server.config.api_keys.includes(apiKey)) {
newApiKeys[apiKey] = '';
}
});
}

setApiKeys(newApiKeys);
}, [lmProvider, emProvider, server]);
Expand Down
11 changes: 10 additions & 1 deletion packages/jupyter-ai/src/handler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,16 @@ export namespace AiService {
type: 'aws';
};

export type AuthStrategy = EnvAuthStrategy | AwsAuthStrategy | null;
export type MultiEnvAuthStrategy = {
type: 'multienv';
names: string[];
};

export type AuthStrategy =
| AwsAuthStrategy
| EnvAuthStrategy
| MultiEnvAuthStrategy
| null;

export type TextField = {
type: 'text';
Expand Down
Loading