From 7b14fd905377b394dd4af8376c3d468eaeac9220 Mon Sep 17 00:00:00 2001 From: chm10 Date: Tue, 17 Sep 2024 01:24:18 -0300 Subject: [PATCH] feat: Add cross-region inference support for Bedrock models (#535) This commit introduces the initial setup for supporting cross-region inference in the Bedrock Chat application. The changes include: - Added `is_region_supported_for_inference` function in `utils.py` to check if a region supports inference. - Modified `get_bedrock_client` function in `utils.py` to use cross-region inference if enabled and the region is supported. - Updated `get_model_id` function in `bedrock.py` to use the base model ID for cross-region inference if enabled, the region is supported, and the model is included in `CROSS_REGION_INFERENCE_MODELS`. If any of these conditions are not met, it falls back to using the local model ID and logs a warning. - Added `enableBedrockCrossRegionInference` option to `cdk.json` with the default value set to `false`. These changes lay the foundation for enabling cross-region inference in the Bedrock Chat application. The feature can be enabled or disabled using the `enableBedrockCrossRegionInference` configuration option in `cdk.json`. --- backend/app/bedrock.py | 55 ++++++++++++++++++++++++++---------------- backend/app/utils.py | 15 +++++++++--- cdk/cdk.json | 1 + 3 files changed, 47 insertions(+), 24 deletions(-) diff --git a/backend/app/bedrock.py b/backend/app/bedrock.py index 9c96ee80..6cce3e08 100644 --- a/backend/app/bedrock.py +++ b/backend/app/bedrock.py @@ -11,7 +11,7 @@ from app.repositories.models.conversation import MessageModel from app.repositories.models.custom_bot import GenerationParamsModel from app.routes.schemas.conversation import type_model_name -from app.utils import convert_dict_keys_to_camel_case, get_bedrock_client +from app.utils import convert_dict_keys_to_camel_case, get_bedrock_client, is_region_supported_for_inference, ENABLE_BEDROCK_CROSS_REGION_INFERENCE from typing_extensions import NotRequired, TypedDict, no_type_check logger = logging.getLogger(__name__) @@ -23,8 +23,9 @@ if ENABLE_MISTRAL else DEFAULT_CLAUDE_GENERATION_CONFIG ) +ENABLE_BEDROCK_CROSS_REGION_INFERENCE = os.environ.get("ENABLE_BEDROCK_CROSS_REGION_INFERENCE", "false").lower() == "true" -client = get_bedrock_client() +client = get_bedrock_client(BEDROCK_REGION) class ConverseApiToolSpec(TypedDict): @@ -219,7 +220,7 @@ def compose_args_for_converse_api( def call_converse_api(args: ConverseApiRequest) -> ConverseApiResponse: - client = get_bedrock_client() + client = get_bedrock_client(BEDROCK_REGION) messages = args["messages"] inference_config = args["inference_config"] additional_model_request_fields = args["additional_model_request_fields"] @@ -256,27 +257,39 @@ def calculate_price( return input_price * input_tokens / 1000.0 + output_price * output_tokens / 1000.0 +CROSS_REGION_INFERENCE_MODELS = { + "claude-v3-sonnet": "anthropic.claude-3-sonnet-20240229-v1:0", + "claude-v3-haiku": "anthropic.claude-3-haiku-20240307-v1:0", + "claude-v3-opus": "anthropic.claude-3-opus-20240229-v1:0", + "claude-v3.5-sonnet": "anthropic.claude-3-5-sonnet-20240620-v1:0", +} def get_model_id(model: type_model_name) -> str: # Ref: https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids-arns.html - if model == "claude-v2": - return "anthropic.claude-v2:1" - elif model == "claude-instant-v1": - return "anthropic.claude-instant-v1" - elif model == "claude-v3-sonnet": - return "anthropic.claude-3-sonnet-20240229-v1:0" - elif model == "claude-v3-haiku": - return "anthropic.claude-3-haiku-20240307-v1:0" - elif model == "claude-v3-opus": - return "anthropic.claude-3-opus-20240229-v1:0" - elif model == "claude-v3.5-sonnet": - return "anthropic.claude-3-5-sonnet-20240620-v1:0" - elif model == "mistral-7b-instruct": - return "mistral.mistral-7b-instruct-v0:2" - elif model == "mixtral-8x7b-instruct": - return "mistral.mixtral-8x7b-instruct-v0:1" - elif model == "mistral-large": - return "mistral.mistral-large-2402-v1:0" + base_model_id = { + "claude-v2": "anthropic.claude-v2:1", + "claude-instant-v1": "anthropic.claude-instant-v1", + "claude-v3-sonnet": "anthropic.claude-3-sonnet-20240229-v1:0", + "claude-v3-haiku": "anthropic.claude-3-haiku-20240307-v1:0", + "claude-v3-opus": "anthropic.claude-3-opus-20240229-v1:0", + "claude-v3.5-sonnet": "anthropic.claude-3-5-sonnet-20240620-v1:0", + "mistral-7b-instruct": "mistral.mistral-7b-instruct-v0:2", + "mixtral-8x7b-instruct": "mistral.mixtral-8x7b-instruct-v0:1", + "mistral-large": "mistral.mistral-large-2402-v1:0", + }[model] + + if (ENABLE_BEDROCK_CROSS_REGION_INFERENCE and + is_region_supported_for_inference(BEDROCK_REGION) and + model in CROSS_REGION_INFERENCE_MODELS): + logger.info(f"Using cross-region inference for model {model} in region {BEDROCK_REGION}") + return base_model_id + else: + if ENABLE_BEDROCK_CROSS_REGION_INFERENCE: + if not is_region_supported_for_inference(BEDROCK_REGION): + logger.warning(f"Cross-region inference is enabled, but the region {BEDROCK_REGION} is not supported. Using local model.") + elif model not in CROSS_REGION_INFERENCE_MODELS: + logger.warning(f"Cross-region inference is not available for model {model}. Using local model.") + return f"{base_model_id}-local" def calculate_query_embedding(question: str) -> list[float]: diff --git a/backend/app/utils.py b/backend/app/utils.py index 58e78905..113b4b50 100644 --- a/backend/app/utils.py +++ b/backend/app/utils.py @@ -19,7 +19,7 @@ "PUBLISH_API_CODEBUILD_PROJECT_NAME", "" ) DB_SECRETS_ARN = os.environ.get("DB_SECRETS_ARN", "") - +ENABLE_BEDROCK_CROSS_REGION_INFERENCE = os.environ.get("ENABLE_BEDROCK_CROSS_REGION_INFERENCE", "false").lower() == "true" def snake_to_camel(snake_str): components = snake_str.split("_") @@ -40,9 +40,18 @@ def is_running_on_lambda(): return "AWS_EXECUTION_ENV" in os.environ +def is_region_supported_for_inference(region: str) -> bool: + supported_regions = ['us-east-1', 'us-west-2', 'eu-west-1', 'eu-west-3', 'eu-central-1'] # Add more as they become available + return region in supported_regions + def get_bedrock_client(region=BEDROCK_REGION): - client = boto3.client("bedrock-runtime", region) - return client + if ENABLE_BEDROCK_CROSS_REGION_INFERENCE and is_region_supported_for_inference(region): + logger.info(f"Using cross-region Bedrock client for region {region}") + return boto3.client("bedrock-runtime", region_name=region) + else: + if ENABLE_BEDROCK_CROSS_REGION_INFERENCE: + logger.warning(f"Cross-region inference is enabled, but the region {region} is not supported. Using default region.") + return boto3.client("bedrock-runtime", region_name=REGION def get_bedrock_agent_client(region=REGION): diff --git a/cdk/cdk.json b/cdk/cdk.json index 38f354c6..0b4ca418 100644 --- a/cdk/cdk.json +++ b/cdk/cdk.json @@ -52,6 +52,7 @@ "@aws-cdk/core:includePrefixInUniqueNameGeneration": true, "@aws-cdk/aws-opensearchservice:enableOpensearchMultiAzWithStandby": true, "enableMistral": false, + "enableBedrockCrossRegionInference": false, "bedrockRegion": "us-east-1", "allowedIpV4AddressRanges": ["0.0.0.0/1", "128.0.0.0/1"], "allowedIpV6AddressRanges": [