Merge pull request #1241 from JGalego/fix/aws-meta-sagemaker
fix: Support for Meta models provided by Amazon SageMaker
arnavsinghvi11 committed Jul 8, 2024
2 parents 282d0a6 + 38f527e commit 46f0c73
Showing 1 changed file with 33 additions and 13 deletions.
46 changes: 33 additions & 13 deletions dsp/modules/aws_models.py
@@ -275,6 +275,8 @@ def __init__(
         self.aws_provider = aws_provider
         self.provider = aws_provider.get_provider_name()
 
+        self.kwargs["stop"] = ["<|eot_id|>"]
+
         for k, v in kwargs.items():
             self.kwargs[k] = v

@@ -290,25 +292,43 @@ def _create_body(self, prompt: str, **kwargs) -> tuple[int, dict[str, str | float]]:
         for k, v in kwargs.items():
             base_args[k] = v
 
-        n, query_args = self.aws_provider.sanitize_kwargs(base_args)
+        n, base_args = self.aws_provider.sanitize_kwargs(base_args)
 
         # Meta models do not support the following parameters
-        query_args.pop("frequency_penalty", None)
-        query_args.pop("num_generations", None)
-        query_args.pop("presence_penalty", None)
-        query_args.pop("model", None)
+        base_args.pop("frequency_penalty", None)
+        base_args.pop("num_generations", None)
+        base_args.pop("presence_penalty", None)
+        base_args.pop("model", None)
 
-        max_tokens = query_args.pop("max_tokens", None)
-        if max_tokens:
-            query_args["max_gen_len"] = max_tokens
+        max_tokens = base_args.pop("max_tokens", None)
 
-        query_args["prompt"] = prompt
+        query_args: dict[str, str | float] = {}
+        if isinstance(self.aws_provider, Bedrock):
+            if max_tokens:
+                base_args["max_gen_len"] = max_tokens
+            query_args = base_args
+            query_args["prompt"] = prompt
+        elif isinstance(self.aws_provider, Sagemaker):
+            if max_tokens:
+                base_args["max_new_tokens"] = max_tokens
+            query_args["parameters"] = base_args
+            query_args["inputs"] = prompt
+        else:
+            raise ValueError("Error - provider not recognized")
+
         return (n, query_args)
 
     def _call_model(self, body: str) -> str:
-        response = self.aws_provider.predictor.invoke_model(
-            modelId=self._model_name,
+        response = self.aws_provider.call_model(
+            model_id=self._model_name,
             body=body,
         )
-        response_body = json.loads(response["body"].read())
-        return response_body["generation"]
+        if isinstance(self.aws_provider, Bedrock):
+            response_body = json.loads(response["body"].read())
+            completion = response_body["generation"]
+        elif isinstance(self.aws_provider, Sagemaker):
+            response_body = json.loads(response["Body"].read())
+            completion = response_body["generated_text"]
+        else:
+            raise ValueError("Error - provider not recognized")
+        return completion
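
The two provider branches differ mainly in how the request body is shaped and where the completion lives in the response: Bedrock takes a flat body keyed by "prompt" (with "max_tokens" renamed to "max_gen_len") and answers under "generation", while a SageMaker endpoint takes a nested "inputs"/"parameters" body (with "max_new_tokens") and answers under "generated_text". Both paths also benefit from the Llama 3 end-of-turn token <|eot_id|> added as a default stop sequence in __init__. A minimal sketch of the two payload shapes implied by the branches above, with illustrative parameter values that are assumptions rather than anything taken from the repository:

import json

prompt = "Hello, Llama!"

# Bedrock path: flat body with "prompt"; "max_tokens" is renamed to "max_gen_len";
# the completion is later read from response_body["generation"].
bedrock_body = json.dumps({
    "prompt": prompt,
    "temperature": 0.7,   # illustrative sanitized kwarg (assumed)
    "max_gen_len": 256,
})

# SageMaker path: nested body with "inputs"/"parameters"; "max_tokens" becomes
# "max_new_tokens"; the completion is later read from response_body["generated_text"].
sagemaker_body = json.dumps({
    "inputs": prompt,
    "parameters": {
        "temperature": 0.7,   # illustrative sanitized kwarg (assumed)
        "max_new_tokens": 256,
    },
})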
