Support query cache for chatbot (#1148)
lvliang-intel authored Jul 24, 2023
1 parent 60e172e commit 1b44631
Showing 6 changed files with 574 additions and 65 deletions.
67 changes: 53 additions & 14 deletions workflows/chatbot/inference/backend/chat/controller.py
@@ -23,7 +23,7 @@
from utils import build_logger, server_error_msg

from sse_starlette.sse import EventSourceResponse

from starlette.responses import RedirectResponse

logger = build_logger("controller", "controller.log")

@@ -206,15 +206,19 @@ def worker_api_generate_stream(self, params):
        try:
            response = requests.post(worker_addr + "/worker_generate_stream",
                                     json=params, stream=True, timeout=1000)
            result = ""
            for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
                if chunk:
                    print("chunk=======", chunk)
                    # yield chunk + b"\0"
                    a = chunk.decode("utf-8")
                    a = re.sub(r'\\u2019', "'", a)
                    a = re.sub(r'\\\\ufffd', '', a)
                    result += a
                    yield f"data: {a}\n\n"
                    # yield f"data: \n\n"
            from ..llmcache.cache import put
            put(params["prompt"], result)
            yield f"data: [DONE]\n\n"
        except requests.exceptions.RequestException as e:
            logger.info(f"worker timeout: {worker_addr}")
@@ -277,13 +281,11 @@ async def register_worker(request: Request):
async def refresh_all_workers():
    models = controller.refresh_all_workers()


@app.post("/list_models")
async def list_models():
    models = controller.list_models()
    return {"models": models}


@app.post("/get_worker_address")
async def get_worker_address(request: Request):
    data = await request.json()
@@ -298,7 +300,6 @@ async def receive_heart_beat(request: Request):
data["worker_name"], data["queue_length"])
return {"exist": exist}


@app.post("/worker_generate_stream")
async def worker_api_generate_stream(request: Request):
    logger.info('Received request: %s', await request.json())
@@ -307,16 +308,42 @@ async def worker_api_generate_stream(request: Request):
params = params["msgData"]
generator = controller.worker_api_generate_stream(params)
return StreamingResponse(generator, media_type="text/event-stream")
#return EventSourceResponse(generator)
"""
async def test_generator():
for i in range(256):
print(i)
yield str(i)
# time.sleep(0.5)
await asyncio.sleep(1)
return EventSourceResponse(test_generator())
"""

@app.post("/v1/models")
async def list_models():
    models = controller.list_models()
    return {"models": models}

@app.post("/v1/chat/completions")
async def worker_api_generate_stream(request: Request):
    logger.info('Received request: %s', await request.json())
    params = await request.json()
    if "msgData" in params:
        params = params["msgData"]
    generator = controller.worker_api_generate_stream(params)
    return StreamingResponse(generator, media_type="text/event-stream")

@app.post("/v1/chat/llmcache")
async def get_cache(request: Request):
    logger.info('Received request: %s', await request.json())
    params = await request.json()
    if "msgData" in params:
        params = params["msgData"]
    prompt = params["prompt"]
    from ..llmcache.cache import get
    result = get(prompt)
    print(result)
    if result is None:
        print("cache miss >>>>>>>>>>>>>>>")
        response = RedirectResponse(url="/v1/chat/completions")
        return response
    else:
        print("cache hit >>>>>>>>>>>>>>>>")
        def stream_results():
            yield "data: Response from Cache: {}\n\n".format(result['choices'][0]['text'])
            yield "data: [DONE]\n\n"

        return StreamingResponse(stream_results(), media_type="text/event-stream")
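For context, a minimal client sketch (not part of this commit) for the cache endpoint above: it POSTs the prompt to /v1/chat/llmcache and reads the SSE stream; on a cache miss the handler returns a RedirectResponse (307) to /v1/chat/completions, which a client that honors 307 redirects will re-POST automatically. The controller URL, model name, and msgData payload shape are assumptions based on the handlers in this diff.

# Hypothetical client; endpoint URL, model name, and payload shape are assumptions.
import requests

CONTROLLER_URL = "http://localhost:80"  # assumed controller address
payload = {"msgData": {"prompt": "What is a query cache?", "model": "your-model-name"}}

# On a cache hit the controller streams the cached answer directly; on a miss it
# replies with a 307 redirect to /v1/chat/completions, which requests re-POSTs.
resp = requests.post(f"{CONTROLLER_URL}/v1/chat/llmcache", json=payload, stream=True)
for line in resp.iter_lines(decode_unicode=True):
    if not line:
        continue
    data = line[len("data: "):] if line.startswith("data: ") else line
    if data == "[DONE]":
        break
    print(data)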

STREAM_DELAY = 1 # second
RETRY_TIMEOUT = 15000 # millisecond
@@ -391,8 +418,20 @@ async def event_stream():
parser.add_argument("--port", type=int, default=80)
parser.add_argument("--dispatch-method", type=str, choices=[
"lottery", "shortest_queue"], default="shortest_queue")
parser.add_argument(
"--cache-chat-config-file", default="cache_config.yml", help="the cache config file"
)
parser.add_argument(
"--cache-embedding-model-dir", default="./instructor-large", help="the cache embedding model directory"
)
args = parser.parse_args()
logger.info(f"args: {args}")

from ..llmcache.cache import init_similar_cache_from_config, put
if args.cache_chat_config_file:
init_similar_cache_from_config(config_dir=args.cache_chat_config_file,
embedding_model_dir=args.cache_embedding_model_dir)
put("test","test")

controller = Controller(args.dispatch_method)
uvicorn.run(app, host=args.host, port=args.port, log_level="info")
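For readers unfamiliar with the llmcache module, here is a rough in-memory stand-in (an assumption for illustration, not the actual implementation, which matches prompts by embedding similarity and is configured via cache_config.yml and an embedding model directory) of the get/put contract the controller relies on: put(prompt, text) stores the streamed completion keyed by its prompt, and get(prompt) returns None on a miss or a completion-shaped object whose answer is read via result['choices'][0]['text'].

# Illustrative stand-in for the llmcache get/put contract (assumption);
# the real cache matches prompts by embedding similarity, not exact keys.
_cache = {}

def put(prompt, text):
    # Store the generated completion keyed by its prompt.
    _cache[prompt] = text

def get(prompt):
    # Return None on a miss; on a hit, return a dict shaped so the caller
    # can read result['choices'][0]['text'].
    text = _cache.get(prompt)
    if text is None:
        return None
    return {"choices": [{"text": text}]}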
