Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add vertex support #182

Merged
merged 21 commits into from
Aug 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 76 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
- [Bedrock](#Bedrock)
- [SageMaker](#SageMaker)
- [Azure](#Azure)
- [Vertex](#Vertex)

## Examples (tl;dr)

Expand Down Expand Up @@ -736,4 +737,79 @@ response = client.chat.completions.create(
)
```

#### Async

```python
import asyncio
from ai21 import AsyncAI21AzureClient
from ai21.models.chat import ChatMessage

client = AsyncAI21AzureClient(
base_url="https://<YOUR-ENDPOINT>.inference.ai.azure.com/v1/chat/completions",
api_key="<your Azure api key>",
)

messages = [
ChatMessage(content="You are a helpful assistant", role="system"),
ChatMessage(content="What is the meaning of life?", role="user")
]

async def main():
response = await client.chat.completions.create(
model="jamba-instruct",
messages=messages,
)

asyncio.run(main())
```

### Vertex

If you wish to interact with your Vertex AI endpoint on GCP, use the `AI21VertexClient`
and `AsyncAI21VertexClient` clients.

The following models are supported on Vertex:

- `jamba-1.5-mini`
- `jamba-1.5-large`

```python
from ai21 import AI21VertexClient

from ai21.models.chat import ChatMessage

# You can also set the project_id, region, access_token and Google credentials in the constructor
client = AI21VertexClient()

messages = ChatMessage(content="What is the meaning of life?", role="user")

response = client.chat.completions.create(
model="jamba-1.5-mini",
messages=[messages],
)
```

#### Async

```python
import asyncio

from ai21 import AsyncAI21VertexClient
from ai21.models.chat import ChatMessage

# You can also set the project_id, region, access_token and Google credentials in the constructor
client = AsyncAI21VertexClient()


async def main():
messages = ChatMessage(content="What is the meaning of life?", role="user")

response = await client.chat.completions.create(
model="jamba-1.5-mini",
messages=[messages],
)

asyncio.run(main())
```

Happy prompting! 🚀
22 changes: 21 additions & 1 deletion ai21/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,18 @@ def _import_async_sagemaker_client():
return AsyncAI21SageMakerClient


def _import_vertex_client():
from ai21.clients.vertex.ai21_vertex_client import AI21VertexClient

return AI21VertexClient


def _import_async_vertex_client():
from ai21.clients.vertex.ai21_vertex_client import AsyncAI21VertexClient

return AsyncAI21VertexClient


def __getattr__(name: str) -> Any:
try:
if name == "AI21BedrockClient":
Expand All @@ -67,8 +79,14 @@ def __getattr__(name: str) -> Any:

if name == "AsyncAI21SageMakerClient":
return _import_async_sagemaker_client()

if name == "AI21VertexClient":
return _import_vertex_client()

if name == "AsyncAI21VertexClient":
return _import_async_vertex_client()
except ImportError as e:
raise ImportError(f'Please install "ai21[AWS]" in order to use {name}') from e
raise ImportError('Please install "ai21[AWS]" for SageMaker or Bedrock, or "ai21[Vertex]" for Vertex') from e


__all__ = [
Expand All @@ -89,4 +107,6 @@ def __getattr__(name: str) -> Any:
"AsyncAI21AzureClient",
"AsyncAI21BedrockClient",
"AsyncAI21SageMakerClient",
"AI21VertexClient",
"AsyncAI21VertexClient",
]
Empty file added ai21/clients/vertex/__init__.py
Empty file.
209 changes: 209 additions & 0 deletions ai21/clients/vertex/ai21_vertex_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,209 @@
from __future__ import annotations

from typing import Optional, Dict, Any

import httpx
from google.auth.credentials import Credentials as GCPCredentials

from ai21.clients.studio.resources.studio_chat import StudioChat, AsyncStudioChat
from ai21.clients.vertex.gcp_authorization import GCPAuthorization
from ai21.http_client.async_http_client import AsyncAI21HTTPClient
from ai21.http_client.http_client import AI21HTTPClient
from ai21.models.request_options import RequestOptions

_DEFAULT_GCP_REGION = "us-central1"
_VERTEX_BASE_URL_FORMAT = "https://{region}-aiplatform.googleapis.com/v1"
_VERTEX_PATH_FORMAT = "/projects/{project_id}/locations/{region}/publishers/ai21/models/{model}:{endpoint}"


class BaseAI21VertexClient:
def __init__(
self,
region: Optional[str] = None,
project_id: Optional[str] = None,
access_token: Optional[str] = None,
credentials: Optional[GCPCredentials] = None,
):
if access_token is not None and project_id is None:
raise ValueError("Field project_id is required when setting access_token")
self._region = region or _DEFAULT_GCP_REGION
self._access_token = access_token
self._project_id = project_id
self._credentials = credentials
self._gcp_auth = GCPAuthorization()

def _get_base_url(self) -> str:
return _VERTEX_BASE_URL_FORMAT.format(region=self._region)

def _get_access_token(self) -> str:
if self._access_token is not None:
return self._access_token

if self._credentials is None:
self._credentials, self._project_id = self._gcp_auth.get_gcp_credentials(
project_id=self._project_id,
)

if self._credentials is None:
raise ValueError("Could not get credentials for GCP project")

self._gcp_auth.refresh_auth(self._credentials)

if self._credentials.token is None:
raise RuntimeError(f"Could not get access token for GCP project {self._project_id}")

return self._credentials.token

def _build_path(
self,
project_id: str,
region: str,
model: str,
endpoint: str,
) -> str:
return _VERTEX_PATH_FORMAT.format(
project_id=project_id,
region=region,
model=model,
endpoint=endpoint,
)

def _get_authorization_header(self) -> Dict[str, Any]:
access_token = self._get_access_token()
return {"Authorization": f"Bearer {access_token}"}


class AI21VertexClient(BaseAI21VertexClient, AI21HTTPClient):
def __init__(
self,
region: Optional[str] = None,
project_id: Optional[str] = None,
base_url: Optional[str] = None,
access_token: Optional[str] = None,
credentials: Optional[GCPCredentials] = None,
headers: Dict[str, str] | None = None,
timeout_sec: Optional[float] = None,
num_retries: Optional[int] = None,
http_client: Optional[httpx.Client] = None,
):
BaseAI21VertexClient.__init__(
self,
region=region,
project_id=project_id,
access_token=access_token,
credentials=credentials,
)

if base_url is None:
base_url = self._get_base_url()

AI21HTTPClient.__init__(
self,
base_url=base_url,
timeout_sec=timeout_sec,
num_retries=num_retries,
headers=headers,
client=http_client,
requires_api_key=False,
)

self.chat = StudioChat(self)
# Override the chat.create method to match the completions endpoint,
# so it wouldn't get to the old J2 completion endpoint
self.chat.create = self.chat.completions.create

def _build_request(self, options: RequestOptions) -> httpx.Request:
options = self._prepare_options(options)

return super()._build_request(options)

def _prepare_options(self, options: RequestOptions) -> RequestOptions:
body = options.body

model = body.pop("model")
stream = body.pop("stream", False)
endpoint = "streamRawPredict" if stream else "rawPredict"
headers = self._prepare_headers()
path = self._build_path(
project_id=self._project_id,
region=self._region,
model=model,
endpoint=endpoint,
)

return options.replace(
body=body,
path=path,
headers=headers,
)

def _prepare_headers(self) -> Dict[str, Any]:
return self._get_authorization_header()


class AsyncAI21VertexClient(BaseAI21VertexClient, AsyncAI21HTTPClient):
def __init__(
self,
region: Optional[str] = None,
project_id: Optional[str] = None,
base_url: Optional[str] = None,
access_token: Optional[str] = None,
credentials: Optional[GCPCredentials] = None,
headers: Dict[str, str] | None = None,
timeout_sec: Optional[float] = None,
num_retries: Optional[int] = None,
http_client: Optional[httpx.AsyncClient] = None,
):
BaseAI21VertexClient.__init__(
self,
region=region,
project_id=project_id,
access_token=access_token,
credentials=credentials,
)

if base_url is None:
base_url = self._get_base_url()

AsyncAI21HTTPClient.__init__(
self,
base_url=base_url,
timeout_sec=timeout_sec,
num_retries=num_retries,
headers=headers,
client=http_client,
requires_api_key=False,
)

self.chat = AsyncStudioChat(self)
# Override the chat.create method to match the completions endpoint,
# so it wouldn't get to the old J2 completion endpoint
self.chat.create = self.chat.completions.create

def _build_request(self, options: RequestOptions) -> httpx.Request:
options = self._prepare_options(options)

return super()._build_request(options)

def _prepare_options(self, options: RequestOptions) -> RequestOptions:
body = options.body

model = body.pop("model")
stream = body.pop("stream", False)
endpoint = "streamRawPredict" if stream else "rawPredict"
headers = self._prepare_headers()
path = self._build_path(
project_id=self._project_id,
region=self._region,
model=model,
endpoint=endpoint,
)

return options.replace(
body=body,
path=path,
headers=headers,
)

def _prepare_headers(self) -> Dict[str, Any]:
return self._get_authorization_header()
43 changes: 43 additions & 0 deletions ai21/clients/vertex/gcp_authorization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from __future__ import annotations

from typing import Optional, Tuple

import google.auth
from google.auth.credentials import Credentials
from google.auth.transport.requests import Request
from google.auth.exceptions import DefaultCredentialsError

from ai21.errors import CredentialsError


class GCPAuthorization:
def get_gcp_credentials(
self,
project_id: Optional[str] = None,
) -> Tuple[Credentials, str]:
try:
credentials, loaded_project_id = google.auth.default(
scopes=["https://www.googleapis.com/auth/cloud-platform"],
)
except DefaultCredentialsError as e:
raise CredentialsError(provider_name="GCP", error_message=str(e))

if project_id is not None and project_id != loaded_project_id:
miri-bar marked this conversation as resolved.
Show resolved Hide resolved
raise ValueError("Mismatch between credentials project id and 'project_id'")

project_id = project_id or loaded_project_id

if project_id is None:
raise ValueError("Could not get project_id for GCP project")

if not isinstance(project_id, str):
raise ValueError(f"Variable project_id must be a string, got {type(project_id)} instead")

return credentials, project_id

def _get_gcp_request(self) -> Request:
return Request()

def refresh_auth(self, credentials: Credentials) -> None:
request = self._get_gcp_request()
credentials.refresh(request)
6 changes: 6 additions & 0 deletions ai21/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,12 @@ def __init__(self, key: str):
super().__init__(message)


class CredentialsError(AI21Error):
def __init__(self, provider_name: str, error_message: str):
message = f"Could not get default {provider_name} credentials: {error_message}"
super().__init__(message)


class StreamingDecodeError(AI21Error):
def __init__(self, chunk: str, error_message: Optional[str] = None):
message = f"Failed to decode chunk: {chunk} in stream. Please check the stream format."
Expand Down
Empty file added examples/vertex/__init__.py
Empty file.
Loading
Loading