Skip to content

Commit

Permalink
Add openai audio adapter (#220)
Browse files Browse the repository at this point in the history
Signed-off-by: shiyu22 <shiyu.chen@zilliz.com>
  • Loading branch information
shiyu22 authored Apr 17, 2023
1 parent 7df259c commit 88bc249
Show file tree
Hide file tree
Showing 4 changed files with 148 additions and 32 deletions.
107 changes: 78 additions & 29 deletions gptcache/adapter/openai.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import time
from typing import Iterator
from typing import Iterator, Any

import base64
from io import BytesIO
Expand All @@ -16,7 +16,8 @@
get_message_from_openai_answer,
get_text_from_openai_answer,
get_image_from_openai_b64,
get_image_from_openai_url
get_image_from_openai_url,
get_audio_text_from_openai_answer,
)
from gptcache.utils import import_pillow

Expand Down Expand Up @@ -66,6 +67,75 @@ def cache_data_convert(cache_data):
**kwargs
)


class Completion(openai.Completion):
    """Openai Completion Wrapper: routes ``create`` through the GPTCache adapter."""

    @classmethod
    def llm_handler(cls, *llm_args, **llm_kwargs):
        """Forward the request to the real OpenAI completion endpoint."""
        return super().create(*llm_args, **llm_kwargs)

    @staticmethod
    def cache_data_convert(cache_data):
        """Shape cached text like an OpenAI completion response."""
        return construct_text_from_cache(cache_data)

    @staticmethod
    def update_cache_callback(llm_data, update_cache_func):
        """Persist the answer text of a fresh response, then pass it through."""
        update_cache_func(get_text_from_openai_answer(llm_data))
        return llm_data

    @classmethod
    def create(cls, *args, **kwargs):
        """Cached drop-in replacement for ``openai.Completion.create``."""
        handler = cls.llm_handler
        convert = cls.cache_data_convert
        callback = cls.update_cache_callback
        return adapt(handler, convert, callback, *args, **kwargs)


class Audio(openai.Audio):
    """Openai Audio Wrapper.

    Caches ``openai.Audio.transcribe`` and ``openai.Audio.translate`` results
    through the GPTCache ``adapt`` pipeline, so a repeated request for the same
    audio content is answered from the cache instead of calling the API again.
    """

    @staticmethod
    def _cache_data_convert(cache_data):
        # On a cache hit, shape the cached text like an OpenAI audio response.
        return construct_audio_text_from_cache(cache_data)

    @staticmethod
    def _update_cache_callback(llm_data, update_cache_func):
        # On a cache miss, persist the freshly returned text, then return the
        # original response unchanged so callers see the normal payload.
        update_cache_func(Answer(get_audio_text_from_openai_answer(llm_data), AnswerType.STR))
        return llm_data

    @classmethod
    def _adapt_call(cls, api_func, model, file, *args, **kwargs):
        """Shared transcribe/translate plumbing: wrap *api_func* with caching.

        Any exception raised by the OpenAI endpoint is re-raised as
        :class:`CacheError` so callers see a single exception type.
        """
        def llm_handler(*llm_args, **llm_kwargs):
            try:
                return api_func(*llm_args, **llm_kwargs)
            except Exception as e:  # boundary wrapper: normalize API errors
                raise CacheError("openai error") from e

        return adapt(
            llm_handler,
            cls._cache_data_convert,
            cls._update_cache_callback,
            *args,
            model=model,
            file=file,
            **kwargs,
        )

    @classmethod
    def transcribe(cls, model: str, file: Any, *args, **kwargs):
        """Cached drop-in replacement for ``openai.Audio.transcribe``."""
        return cls._adapt_call(openai.Audio.transcribe, model, file, *args, **kwargs)

    @classmethod
    def translate(cls, model: str, file: Any, *args, **kwargs):
        """Cached drop-in replacement for ``openai.Audio.translate``."""
        return cls._adapt_call(openai.Audio.translate, model, file, *args, **kwargs)


class Image(openai.Image):
"""Openai Image Wrapper"""

Expand Down Expand Up @@ -143,33 +213,6 @@ def construct_stream_resp_from_cache(return_message):
]


class Completion(openai.Completion):
    """Openai Completion Wrapper"""

    @classmethod
    def llm_handler(cls, *llm_args, **llm_kwargs):
        # Fall through to the real OpenAI completion endpoint on a cache miss.
        return super().create(*llm_args, **llm_kwargs)

    @staticmethod
    def cache_data_convert(cache_data):
        # On a cache hit, wrap the cached text in an OpenAI-style response dict.
        return construct_text_from_cache(cache_data)

    @staticmethod
    def update_cache_callback(llm_data, update_cache_func):
        # Persist the answer text from a fresh API response, then return the
        # response unchanged so the caller sees the normal OpenAI payload.
        update_cache_func(get_text_from_openai_answer(llm_data))
        return llm_data

    @classmethod
    def create(cls, *args, **kwargs):
        # Cached drop-in replacement for openai.Completion.create.
        return adapt(
            cls.llm_handler,
            cls.cache_data_convert,
            cls.update_cache_callback,
            *args,
            **kwargs
        )


def construct_text_from_cache(return_text):
return {
"gptcache": True,
Expand Down Expand Up @@ -214,3 +257,9 @@ def construct_image_create_resp_from_cache(image_data, response_format, size):
{response_format: image_data}
]
}


def construct_audio_text_from_cache(return_text):
    """Build a minimal OpenAI-audio-style response dict around cached text."""
    return {"text": return_text}
6 changes: 6 additions & 0 deletions gptcache/processor/pre.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,9 @@ def nop(data: Dict[str, Any], **_: Dict[str, Any]) -> Any:

def get_prompt(data: Dict[str, Any], **_: Dict[str, Any]) -> Any:
    """Pre-embedding function: return the ``prompt`` field, or ``None`` if absent."""
    try:
        return data["prompt"]
    except KeyError:
        return None


def get_file_bytes(data: Dict[str, Any], **_: Dict[str, Any]) -> bytes:
    """Pre-embedding function: return the raw bytes of the ``file`` field.

    ``peek()`` is used (rather than ``read()``) so the stream position is not
    advanced -- the same file object can still be consumed by the API call
    afterwards.  NOTE(review): ``peek()`` only guarantees the buffered prefix
    of the stream; presumably the buffer covers the whole payload here --
    confirm against callers.

    :raises ValueError: when *data* has no ``file`` entry.
    :raises TypeError: when ``peek()`` does not yield ``bytes``.
    """
    file = data.get("file")
    if file is None:
        raise ValueError("pre-embedding data has no 'file' entry")
    res = file.peek()
    # Explicit check instead of `assert`: asserts are stripped under `-O`.
    if not isinstance(res, bytes):
        raise TypeError(f"expected bytes from file.peek(), got {type(res).__name__}")
    return res
4 changes: 4 additions & 0 deletions gptcache/utils/response.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,7 @@ def get_image_from_path(openai_resp):
with open(img_path, "rb") as f:
img_data = base64.b64encode(f.read())
return img_data


def get_audio_text_from_openai_answer(openai_resp):
    """Extract the transcription/translation text from an OpenAI audio response."""
    text = openai_resp["text"]
    return text
63 changes: 60 additions & 3 deletions tests/unit_tests/adapter/test_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,17 @@
get_text_from_openai_answer,
get_image_from_openai_b64,
get_image_from_path,
get_image_from_openai_url
get_image_from_openai_url,
get_audio_text_from_openai_answer
)
from gptcache.adapter import openai
from gptcache import cache
from gptcache.processor.pre import get_prompt
from gptcache.manager import get_data_manager
from gptcache.processor.pre import get_prompt, get_file_bytes

import os
import base64
import requests
from urllib.request import urlopen
from io import BytesIO
try:
from PIL import Image
Expand Down Expand Up @@ -199,3 +201,58 @@ def test_image_create():
img_returned = get_image_from_path(response)
assert img_returned == expected_img_data
os.remove(response["data"][0]["url"])


def test_audio_transcribe():
    """Transcription should be cached: a repeated identical request must
    return the same text as the first (mocked) API answer."""
    cache.init(pre_embedding_func=get_file_bytes)
    url = "https://github.com/towhee-io/examples/releases/download/data/blues.00000.mp3"
    audio_file = urlopen(url)
    expect_answer = "One bourbon, one scotch and one bill Hey Mr. Bartender, come here I want another drink and I want it now My baby she gone, " \
                    "she been gone tonight I ain't seen my baby since night of her life One bourbon, one scotch and one bill"

    with patch("openai.Audio.transcribe") as mock_create:
        mock_create.return_value = {"text": expect_answer}

        # First call populates the cache; the second should be served from it.
        for _ in range(2):
            response = openai.Audio.transcribe(
                model="whisper-1",
                file=audio_file
            )
            assert get_audio_text_from_openai_answer(response) == expect_answer

def test_audio_translate():
    """Translation should be cached: a repeated identical request must
    return the same text as the first (mocked) API answer."""
    cache.init(pre_embedding_func=get_file_bytes,
               data_manager=get_data_manager(data_path="data_map1.txt"))
    url = "https://github.com/towhee-io/examples/releases/download/data/blues.00000.mp3"
    audio_file = urlopen(url)
    expect_answer = "One bourbon, one scotch and one bill Hey Mr. Bartender, come here I want another drink and I want it now My baby she gone, " \
                    "she been gone tonight I ain't seen my baby since night of her life One bourbon, one scotch and one bill"

    with patch("openai.Audio.translate") as mock_create:
        mock_create.return_value = {"text": expect_answer}

        # First call populates the cache; the second should be served from it.
        for _ in range(2):
            response = openai.Audio.translate(
                model="whisper-1",
                file=audio_file
            )
            assert get_audio_text_from_openai_answer(response) == expect_answer

0 comments on commit 88bc249

Please sign in to comment.