Add openai audio adapter
Signed-off-by: shiyu22 <shiyu.chen@zilliz.com>
shiyu22 committed Apr 17, 2023
1 parent f7587d2 commit b160fe4
Showing 4 changed files with 145 additions and 31 deletions.
105 changes: 77 additions & 28 deletions gptcache/adapter/openai.py
@@ -16,7 +16,8 @@
get_message_from_openai_answer,
get_text_from_openai_answer,
get_image_from_openai_b64,
get_image_from_openai_url
get_image_from_openai_url,
get_audio_text_from_openai_answer,
)
from gptcache.utils import import_pillow

@@ -66,6 +67,75 @@ def cache_data_convert(cache_data):
**kwargs
)


class Completion(openai.Completion):
"""Openai Completion Wrapper"""

@classmethod
def llm_handler(cls, *llm_args, **llm_kwargs):
return super().create(*llm_args, **llm_kwargs)

@staticmethod
def cache_data_convert(cache_data):
return construct_text_from_cache(cache_data)

@staticmethod
def update_cache_callback(llm_data, update_cache_func):
update_cache_func(get_text_from_openai_answer(llm_data))
return llm_data

@classmethod
def create(cls, *args, **kwargs):
return adapt(
cls.llm_handler,
cls.cache_data_convert,
cls.update_cache_callback,
*args,
**kwargs
)


class Audio(openai.Audio):
"""Openai Audio Wrapper"""
@classmethod
def transcribe(cls, *args, **kwargs):
def llm_handler(*llm_args, **llm_kwargs):
try:
return openai.Audio.transcribe(*llm_args, **llm_kwargs)
except Exception as e:
raise CacheError("openai error") from e

def cache_data_convert(cache_data):
return construct_audio_text_from_cache(cache_data)

def update_cache_callback(llm_data, update_cache_func):
update_cache_func(Answer(get_audio_text_from_openai_answer(llm_data), AnswerType.STR))
return llm_data

return adapt(
llm_handler, cache_data_convert, update_cache_callback, *args, **kwargs
)

@classmethod
def translate(cls, *args, **kwargs):
def llm_handler(*llm_args, **llm_kwargs):
try:
return openai.Audio.translate(*llm_args, **llm_kwargs)
except Exception as e:
raise CacheError("openai error") from e

def cache_data_convert(cache_data):
return construct_audio_text_from_cache(cache_data)

def update_cache_callback(llm_data, update_cache_func):
update_cache_func(Answer(get_audio_text_from_openai_answer(llm_data), AnswerType.STR))
return llm_data

return adapt(
llm_handler, cache_data_convert, update_cache_callback, *args, **kwargs
)


class Image(openai.Image):
"""Openai Image Wrapper"""

@@ -143,33 +213,6 @@ def construct_stream_resp_from_cache(return_message):
]


class Completion(openai.Completion):
"""Openai Completion Wrapper"""

@classmethod
def llm_handler(cls, *llm_args, **llm_kwargs):
return super().create(*llm_args, **llm_kwargs)

@staticmethod
def cache_data_convert(cache_data):
return construct_text_from_cache(cache_data)

@staticmethod
def update_cache_callback(llm_data, update_cache_func):
update_cache_func(get_text_from_openai_answer(llm_data))
return llm_data

@classmethod
def create(cls, *args, **kwargs):
return adapt(
cls.llm_handler,
cls.cache_data_convert,
cls.update_cache_callback,
*args,
**kwargs
)


def construct_text_from_cache(return_text):
return {
"gptcache": True,
@@ -214,3 +257,9 @@ def construct_image_create_resp_from_cache(image_data, response_format, size):
{response_format: image_data}
]
}


def construct_audio_text_from_cache(return_text):
return {
"text": return_text,
}
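
Note: the new Audio wrapper follows the same adapt() pattern as the other adapters in this file: the handler calls the real OpenAI endpoint, the cache side rebuilds a Whisper-style {"text": ...} dict via construct_audio_text_from_cache, and the callback writes the transcription text back into the cache. A minimal usage sketch, assuming the same setup as the unit tests below (cache initialized with get_file as the pre-embedding function; the file name and configured OpenAI credentials are assumptions for illustration):

# Hedged usage sketch, not part of the commit; mirrors the unit tests below.
from gptcache import cache
from gptcache.processor.pre import get_file
from gptcache.adapter import openai

cache.init(pre_embedding_func=get_file)

# Read the audio payload as bytes, as the tests do.
with open("blues.00000.wav", "rb") as f:
    audio_data = f.read()

# The first call reaches the OpenAI Whisper endpoint; a repeated call with the
# same file can be served from the cache via construct_audio_text_from_cache.
transcription = openai.Audio.transcribe(model="whisper-1", file=audio_data)
print(transcription["text"])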
4 changes: 4 additions & 0 deletions gptcache/processor/pre.py
@@ -33,3 +33,7 @@ def nop(data: Dict[str, Any], **_: Dict[str, Any]) -> Any:

def get_prompt(data: Dict[str, Any], **_: Dict[str, Any]) -> Any:
return data.get("prompt")


def get_file(data: Dict[str, Any], **_: Dict[str, Any]) -> Any:
return data.get("file")
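
Note: get_file mirrors get_prompt: it picks the request field that acts as the cache key for audio calls. A minimal sketch of what it returns, with made-up request kwargs:

# Illustrative only: get_file returns the "file" entry of the request kwargs,
# so identical audio payloads map to the same cache entry.
from gptcache.processor.pre import get_file

request_kwargs = {"model": "whisper-1", "file": b"<audio bytes>"}
assert get_file(request_kwargs) == b"<audio bytes>"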
4 changes: 4 additions & 0 deletions gptcache/utils/response.py
@@ -29,3 +29,7 @@ def get_image_from_path(openai_resp):
with open(img_path, "rb") as f:
img_data = base64.b64encode(f.read())
return img_data


def get_audio_text_from_openai_answer(openai_resp):
return openai_resp["text"]
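
Note: get_audio_text_from_openai_answer is the read-side counterpart of construct_audio_text_from_cache: both deal in a plain {"text": ...} dict. A tiny sketch with a fabricated response:

# Illustrative only: the helper reads the "text" field of a Whisper-style response.
from gptcache.utils.response import get_audio_text_from_openai_answer

openai_resp = {"text": "a short transcription"}
assert get_audio_text_from_openai_answer(openai_resp) == "a short transcription"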
63 changes: 60 additions & 3 deletions tests/unit_tests/adapter/test_openai.py
@@ -5,15 +5,17 @@
get_text_from_openai_answer,
get_image_from_openai_b64,
get_image_from_path,
get_image_from_openai_url
get_image_from_openai_url,
get_audio_text_from_openai_answer
)
from gptcache.adapter import openai
from gptcache import cache
from gptcache.processor.pre import get_prompt
from gptcache.manager import get_data_manager
from gptcache.processor.pre import get_prompt, get_file

import os
import base64
import requests
from urllib.request import urlopen
from io import BytesIO
try:
from PIL import Image
@@ -199,3 +201,58 @@ def test_image_create():
img_returned = get_image_from_path(response)
assert img_returned == expected_img_data
os.remove(response["data"][0]["url"])


def test_audio_transcribe():
cache.init(pre_embedding_func=get_file)
url = "https://github.com/towhee-io/examples/releases/download/data/blues.00000.wav"
audio_file = urlopen(url).read()
expect_answer = "One bourbon, one scotch and one bill Hey Mr. Bartender, come here I want another drink and I want it now My baby she gone, " \
"she been gone tonight I ain't seen my baby since night of her life One bourbon, one scotch and one bill"

with patch("openai.Audio.transcribe") as mock_create:
mock_create.return_value = {
"text": expect_answer
}

response = openai.Audio.transcribe(
model="whisper-1",
file=audio_file
)
answer_text = get_audio_text_from_openai_answer(response)
assert answer_text == expect_answer

response = openai.Audio.transcribe(
model="whisper-1",
file=audio_file
)
answer_text = get_audio_text_from_openai_answer(response)
assert answer_text == expect_answer


def test_audio_translate():
cache.init(pre_embedding_func=get_file,
data_manager=get_data_manager(data_path="data_map1.txt"))
url = "https://github.com/towhee-io/examples/releases/download/data/blues.00000.wav"
audio_file = urlopen(url).read()
expect_answer = "One bourbon, one scotch and one bill Hey Mr. Bartender, come here I want another drink and I want it now My baby she gone, " \
"she been gone tonight I ain't seen my baby since night of her life One bourbon, one scotch and one bill"

with patch("openai.Audio.translate") as mock_create:
mock_create.return_value = {
"text": expect_answer
}

response = openai.Audio.translate(
model="whisper-1",
file=audio_file
)
answer_text = get_audio_text_from_openai_answer(response)
assert answer_text == expect_answer

response = openai.Audio.translate(
model="whisper-1",
file=audio_file
)
answer_text = get_audio_text_from_openai_answer(response)
assert answer_text == expect_answer
