Skip to content

Commit

Permalink
feat: Add cross-region inference support for Bedrock models (#535)
Browse files Browse the repository at this point in the history
This commit introduces the initial setup for supporting cross-region inference in the Bedrock Chat application. The changes include:

- Added `is_region_supported_for_inference` function in `utils.py` to check whether a region supports cross-region inference.
- Modified `get_bedrock_client` function in `utils.py` to use cross-region inference if enabled and the region is supported.
- Updated `get_model_id` function in `bedrock.py` to use the base model ID for cross-region inference if enabled, the region is supported, and the model is included in `CROSS_REGION_INFERENCE_MODELS`. If any of these conditions are not met, it falls back to using the local model ID and logs a warning.
- Added `enableBedrockCrossRegionInference` option to `cdk.json` with the default value set to `false`.

These changes lay the foundation for enabling cross-region inference in the Bedrock Chat application. The feature can be enabled or disabled using the `enableBedrockCrossRegionInference` configuration option in `cdk.json`.
  • Loading branch information
chm10 committed Sep 17, 2024
1 parent a759f3d commit 7b14fd9
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 24 deletions.
55 changes: 34 additions & 21 deletions backend/app/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from app.repositories.models.conversation import MessageModel
from app.repositories.models.custom_bot import GenerationParamsModel
from app.routes.schemas.conversation import type_model_name
from app.utils import convert_dict_keys_to_camel_case, get_bedrock_client
from app.utils import convert_dict_keys_to_camel_case, get_bedrock_client, is_region_supported_for_inference, ENABLE_BEDROCK_CROSS_REGION_INFERENCE
from typing_extensions import NotRequired, TypedDict, no_type_check

logger = logging.getLogger(__name__)
Expand All @@ -23,8 +23,9 @@
if ENABLE_MISTRAL
else DEFAULT_CLAUDE_GENERATION_CONFIG
)
# ENABLE_BEDROCK_CROSS_REGION_INFERENCE is imported from app.utils; it is
# deliberately not re-derived from the environment here, so that both modules
# always agree on the value of the flag.

# Module-level Bedrock runtime client, created once for the configured
# Bedrock region.
client = get_bedrock_client(BEDROCK_REGION)


class ConverseApiToolSpec(TypedDict):
Expand Down Expand Up @@ -219,7 +220,7 @@ def compose_args_for_converse_api(


def call_converse_api(args: ConverseApiRequest) -> ConverseApiResponse:
client = get_bedrock_client()
client = get_bedrock_client(BEDROCK_REGION)
messages = args["messages"]
inference_config = args["inference_config"]
additional_model_request_fields = args["additional_model_request_fields"]
Expand Down Expand Up @@ -256,27 +257,39 @@ def calculate_price(

return input_price * input_tokens / 1000.0 + output_price * output_tokens / 1000.0

# Logical model names that may be served through cross-region inference,
# mapped to their base Bedrock model IDs. Membership in this mapping is what
# get_model_id checks before opting a model into cross-region inference.
CROSS_REGION_INFERENCE_MODELS: dict[str, str] = {
    "claude-v3-sonnet": "anthropic.claude-3-sonnet-20240229-v1:0",
    "claude-v3-haiku": "anthropic.claude-3-haiku-20240307-v1:0",
    "claude-v3-opus": "anthropic.claude-3-opus-20240229-v1:0",
    "claude-v3.5-sonnet": "anthropic.claude-3-5-sonnet-20240620-v1:0",
}

def get_model_id(model: type_model_name) -> str:
    """Return the identifier to pass to Bedrock for a logical model name.

    By default this is the plain foundation-model ID. When
    ``ENABLE_BEDROCK_CROSS_REGION_INFERENCE`` is set, ``BEDROCK_REGION``
    passes ``is_region_supported_for_inference`` and ``model`` is listed in
    ``CROSS_REGION_INFERENCE_MODELS``, the ID is prefixed with the region's
    geography (for example ``us.anthropic.claude-3-haiku-20240307-v1:0``),
    which is the form of a cross-region inference profile ID.

    Args:
        model: Logical model name, e.g. ``"claude-v3-haiku"``.

    Returns:
        The model ID, or the geography-prefixed inference profile ID when
        cross-region inference applies.

    Raises:
        KeyError: If ``model`` is not one of the known logical model names.
    """
    # Ref: https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids-arns.html
    base_model_id = {
        "claude-v2": "anthropic.claude-v2:1",
        "claude-instant-v1": "anthropic.claude-instant-v1",
        "claude-v3-sonnet": "anthropic.claude-3-sonnet-20240229-v1:0",
        "claude-v3-haiku": "anthropic.claude-3-haiku-20240307-v1:0",
        "claude-v3-opus": "anthropic.claude-3-opus-20240229-v1:0",
        "claude-v3.5-sonnet": "anthropic.claude-3-5-sonnet-20240620-v1:0",
        "mistral-7b-instruct": "mistral.mistral-7b-instruct-v0:2",
        "mixtral-8x7b-instruct": "mistral.mixtral-8x7b-instruct-v0:1",
        "mistral-large": "mistral.mistral-large-2402-v1:0",
    }[model]

    # Feature switched off: always invoke the model directly in BEDROCK_REGION.
    if not ENABLE_BEDROCK_CROSS_REGION_INFERENCE:
        return base_model_id

    # The fallbacks below return the unmodified base ID; appending any suffix
    # (such as "-local") would yield an identifier Bedrock does not recognise.
    if not is_region_supported_for_inference(BEDROCK_REGION):
        logger.warning(f"Cross-region inference is enabled, but the region {BEDROCK_REGION} is not supported. Using local model.")
        return base_model_id

    if model not in CROSS_REGION_INFERENCE_MODELS:
        logger.warning(f"Cross-region inference is not available for model {model}. Using local model.")
        return base_model_id

    # Ref: https://docs.aws.amazon.com/bedrock/latest/userguide/cross-region-inference.html
    # Inference profile IDs are the base model ID prefixed with the geography
    # of the source region ("us-east-1" -> "us", "eu-west-1" -> "eu").
    # NOTE(review): taking the first region segment is only known to be right
    # for us-*/eu-* regions; other geographies would need an explicit mapping
    # before being added to the supported-region list.
    geo_prefix = BEDROCK_REGION.split("-", 1)[0]
    logger.info(f"Using cross-region inference for model {model} in region {BEDROCK_REGION}")
    return f"{geo_prefix}.{base_model_id}"


def calculate_query_embedding(question: str) -> list[float]:
Expand Down
15 changes: 12 additions & 3 deletions backend/app/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
"PUBLISH_API_CODEBUILD_PROJECT_NAME", ""
)
# Value of the DB_SECRETS_ARN environment variable; empty string when unset.
DB_SECRETS_ARN = os.getenv("DB_SECRETS_ARN", "")

# Feature flag for Bedrock cross-region inference. Enabled only when the
# environment variable equals "true" (compared case-insensitively); unset or
# any other value leaves it disabled.
ENABLE_BEDROCK_CROSS_REGION_INFERENCE = (
    os.getenv("ENABLE_BEDROCK_CROSS_REGION_INFERENCE", "false").lower() == "true"
)

def snake_to_camel(snake_str):
components = snake_str.split("_")
Expand All @@ -40,9 +40,18 @@ def is_running_on_lambda():
return "AWS_EXECUTION_ENV" in os.environ


def is_region_supported_for_inference(region: str) -> bool:
    """Tell whether ``region`` is on the hard-coded list of supported regions.

    Args:
        region: AWS region name, e.g. ``"us-east-1"``. Matching is exact and
            case-sensitive.

    Returns:
        ``True`` if ``region`` equals one of the listed regions, else ``False``.
    """
    # Add more as they become available
    known_regions = (
        "us-east-1",
        "us-west-2",
        "eu-west-1",
        "eu-west-3",
        "eu-central-1",
    )
    return any(known == region for known in known_regions)

def get_bedrock_client(region=BEDROCK_REGION):
    """Create a ``bedrock-runtime`` boto3 client for ``region``.

    The client is always bound to the requested region. Cross-region
    inference does not need a different endpoint: requests are still sent to
    the source region and are routed by the inference profile ID used as the
    model ID. The cross-region flag therefore only affects what is logged
    here, never which region the client talks to.

    Args:
        region: AWS region for the client. Defaults to ``BEDROCK_REGION``.

    Returns:
        A boto3 ``bedrock-runtime`` client for ``region``.
    """
    if ENABLE_BEDROCK_CROSS_REGION_INFERENCE:
        if is_region_supported_for_inference(region):
            logger.info(f"Using cross-region Bedrock client for region {region}")
        else:
            # Still honour the caller's region: silently switching to another
            # region would send requests somewhere the caller did not ask for.
            logger.warning(f"Cross-region inference is enabled, but the region {region} is not supported. Using the regional endpoint without cross-region inference.")
    return boto3.client("bedrock-runtime", region_name=region)


def get_bedrock_agent_client(region=REGION):
Expand Down
1 change: 1 addition & 0 deletions cdk/cdk.json
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
"@aws-cdk/core:includePrefixInUniqueNameGeneration": true,
"@aws-cdk/aws-opensearchservice:enableOpensearchMultiAzWithStandby": true,
"enableMistral": false,
"enableBedrockCrossRegionInference": false,
"bedrockRegion": "us-east-1",
"allowedIpV4AddressRanges": ["0.0.0.0/1", "128.0.0.0/1"],
"allowedIpV6AddressRanges": [
Expand Down

0 comments on commit 7b14fd9

Please sign in to comment.