From 38f527ebfb34fb10c5f9c646ef7e94aea2e4cfb3 Mon Sep 17 00:00:00 2001 From: JGalego Date: Wed, 3 Jul 2024 17:33:33 +0100 Subject: [PATCH] Refactored to add support for Meta models provided by Amazon SageMaker --- dsp/modules/aws_models.py | 46 ++++++++++++++++++++++++++++----------- 1 file changed, 33 insertions(+), 13 deletions(-) diff --git a/dsp/modules/aws_models.py b/dsp/modules/aws_models.py index ba2fa0848..820dfc67e 100644 --- a/dsp/modules/aws_models.py +++ b/dsp/modules/aws_models.py @@ -275,6 +275,8 @@ def __init__( self.aws_provider = aws_provider self.provider = aws_provider.get_provider_name() + self.kwargs["stop"] = ["<|eot_id|>"] + for k, v in kwargs.items(): self.kwargs[k] = v @@ -290,25 +292,43 @@ def _create_body(self, prompt: str, **kwargs) -> tuple[int, dict[str, str | floa for k, v in kwargs.items(): base_args[k] = v - n, query_args = self.aws_provider.sanitize_kwargs(base_args) + n, base_args = self.aws_provider.sanitize_kwargs(base_args) # Meta models do not support the following parameters - query_args.pop("frequency_penalty", None) - query_args.pop("num_generations", None) - query_args.pop("presence_penalty", None) - query_args.pop("model", None) + base_args.pop("frequency_penalty", None) + base_args.pop("num_generations", None) + base_args.pop("presence_penalty", None) + base_args.pop("model", None) - max_tokens = query_args.pop("max_tokens", None) - if max_tokens: - query_args["max_gen_len"] = max_tokens + max_tokens = base_args.pop("max_tokens", None) + + query_args: dict[str, str | float] = {} + if isinstance(self.aws_provider, Bedrock): + if max_tokens: + base_args["max_gen_len"] = max_tokens + query_args = base_args + query_args["prompt"] = prompt + elif isinstance(self.aws_provider, Sagemaker): + if max_tokens: + base_args["max_new_tokens"] = max_tokens + query_args["parameters"] = base_args + query_args["inputs"] = prompt + else: + raise ValueError("Error - provider not recognized") - query_args["prompt"] = prompt return (n, query_args) def _call_model(self, body: str) -> str: - response = self.aws_provider.predictor.invoke_model( - modelId=self._model_name, + response = self.aws_provider.call_model( + model_id=self._model_name, body=body, ) - response_body = json.loads(response["body"].read()) - return response_body["generation"] + if isinstance(self.aws_provider, Bedrock): + response_body = json.loads(response["body"].read()) + completion = response_body["generation"] + elif isinstance(self.aws_provider, Sagemaker): + response_body = json.loads(response["Body"].read()) + completion = response_body["generated_text"] + else: + raise ValueError("Error - provider not recognized") + return completion